lr_scheduler.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616
  1. import types
  2. import math
  3. from torch._six import inf
  4. from functools import wraps
  5. import warnings
  6. import weakref
  7. from collections import Counter
  8. from bisect import bisect_right
  9. from .optimizer import Optimizer
  10. EPOCH_DEPRECATION_WARNING = (
  11. "The epoch parameter in `scheduler.step()` was not necessary and is being "
  12. "deprecated where possible. Please use `scheduler.step()` to step the "
  13. "scheduler. During the deprecation, if epoch is different from None, the "
  14. "closed form is used instead of the new chainable form, where available. "
  15. "Please open an issue if you are unable to replicate your use case: "
  16. "https://github.com/pytorch/pytorch/issues/new/choose."
  17. )
  18. class _LRScheduler(object):
  19. def __init__(self, optimizer, last_epoch=-1, verbose=False):
  20. # Attach optimizer
  21. if not isinstance(optimizer, Optimizer):
  22. raise TypeError('{} is not an Optimizer'.format(
  23. type(optimizer).__name__))
  24. self.optimizer = optimizer
  25. # Initialize epoch and base learning rates
  26. if last_epoch == -1:
  27. for group in optimizer.param_groups:
  28. group.setdefault('initial_lr', group['lr'])
  29. else:
  30. for i, group in enumerate(optimizer.param_groups):
  31. if 'initial_lr' not in group:
  32. raise KeyError("param 'initial_lr' is not specified "
  33. "in param_groups[{}] when resuming an optimizer".format(i))
  34. self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
  35. self.last_epoch = last_epoch
  36. # Following https://github.com/pytorch/pytorch/issues/20124
  37. # We would like to ensure that `lr_scheduler.step()` is called after
  38. # `optimizer.step()`
  39. def with_counter(method):
  40. if getattr(method, '_with_counter', False):
  41. # `optimizer.step()` has already been replaced, return.
  42. return method
  43. # Keep a weak reference to the optimizer instance to prevent
  44. # cyclic references.
  45. instance_ref = weakref.ref(method.__self__)
  46. # Get the unbound method for the same purpose.
  47. func = method.__func__
  48. cls = instance_ref().__class__
  49. del method
  50. @wraps(func)
  51. def wrapper(*args, **kwargs):
  52. instance = instance_ref()
  53. instance._step_count += 1
  54. wrapped = func.__get__(instance, cls)
  55. return wrapped(*args, **kwargs)
  56. # Note that the returned function here is no longer a bound method,
  57. # so attributes like `__func__` and `__self__` no longer exist.
  58. wrapper._with_counter = True
  59. return wrapper
  60. self.optimizer.step = with_counter(self.optimizer.step)
  61. self.optimizer._step_count = 0
  62. self._step_count = 0
  63. self.verbose = verbose
  64. self.step()
  65. def state_dict(self):
  66. """Returns the state of the scheduler as a :class:`dict`.
  67. It contains an entry for every variable in self.__dict__ which
  68. is not the optimizer.
  69. """
  70. return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
  71. def load_state_dict(self, state_dict):
  72. """Loads the schedulers state.
  73. Args:
  74. state_dict (dict): scheduler state. Should be an object returned
  75. from a call to :meth:`state_dict`.
  76. """
  77. self.__dict__.update(state_dict)
  78. def get_last_lr(self):
  79. """ Return last computed learning rate by current scheduler.
  80. """
  81. return self._last_lr
  82. def get_lr(self):
  83. # Compute learning rate using chainable form of the scheduler
  84. raise NotImplementedError
  85. def print_lr(self, is_verbose, group, lr, epoch=None):
  86. """Display the current learning rate.
  87. """
  88. if is_verbose:
  89. if epoch is None:
  90. print('Adjusting learning rate'
  91. ' of group {} to {:.4e}.'.format(group, lr))
  92. else:
  93. epoch_str = ("%.2f" if isinstance(epoch, float) else
  94. "%.5d") % epoch
  95. print('Epoch {}: adjusting learning rate'
  96. ' of group {} to {:.4e}.'.format(epoch_str, group, lr))
  97. def step(self, epoch=None):
  98. # Raise a warning if old pattern is detected
  99. # https://github.com/pytorch/pytorch/issues/20124
  100. if self._step_count == 1:
  101. if not hasattr(self.optimizer.step, "_with_counter"):
  102. warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
  103. "initialization. Please, make sure to call `optimizer.step()` before "
  104. "`lr_scheduler.step()`. See more details at "
  105. "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
  106. # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
  107. elif self.optimizer._step_count < 1:
  108. warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
  109. "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
  110. "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
  111. "will result in PyTorch skipping the first value of the learning rate schedule. "
  112. "See more details at "
  113. "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
  114. self._step_count += 1
  115. class _enable_get_lr_call:
  116. def __init__(self, o):
  117. self.o = o
  118. def __enter__(self):
  119. self.o._get_lr_called_within_step = True
  120. return self
  121. def __exit__(self, type, value, traceback):
  122. self.o._get_lr_called_within_step = False
  123. with _enable_get_lr_call(self):
  124. if epoch is None:
  125. self.last_epoch += 1
  126. values = self.get_lr()
  127. else:
  128. warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
  129. self.last_epoch = epoch
  130. if hasattr(self, "_get_closed_form_lr"):
  131. values = self._get_closed_form_lr()
  132. else:
  133. values = self.get_lr()
  134. for i, data in enumerate(zip(self.optimizer.param_groups, values)):
  135. param_group, lr = data
  136. param_group['lr'] = lr
  137. self.print_lr(self.verbose, i, lr, epoch)
  138. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  139. class LambdaLR(_LRScheduler):
  140. """Sets the learning rate of each parameter group to the initial lr
  141. times a given function. When last_epoch=-1, sets initial lr as lr.
  142. Args:
  143. optimizer (Optimizer): Wrapped optimizer.
  144. lr_lambda (function or list): A function which computes a multiplicative
  145. factor given an integer parameter epoch, or a list of such
  146. functions, one for each group in optimizer.param_groups.
  147. last_epoch (int): The index of last epoch. Default: -1.
  148. verbose (bool): If ``True``, prints a message to stdout for
  149. each update. Default: ``False``.
  150. Example:
  151. >>> # Assuming optimizer has two groups.
  152. >>> lambda1 = lambda epoch: epoch // 30
  153. >>> lambda2 = lambda epoch: 0.95 ** epoch
  154. >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
  155. >>> for epoch in range(100):
  156. >>> train(...)
  157. >>> validate(...)
  158. >>> scheduler.step()
  159. """
  160. def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
  161. self.optimizer = optimizer
  162. if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
  163. self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
  164. else:
  165. if len(lr_lambda) != len(optimizer.param_groups):
  166. raise ValueError("Expected {} lr_lambdas, but got {}".format(
  167. len(optimizer.param_groups), len(lr_lambda)))
  168. self.lr_lambdas = list(lr_lambda)
  169. super(LambdaLR, self).__init__(optimizer, last_epoch, verbose)
  170. def state_dict(self):
  171. """Returns the state of the scheduler as a :class:`dict`.
  172. It contains an entry for every variable in self.__dict__ which
  173. is not the optimizer.
  174. The learning rate lambda functions will only be saved if they are callable objects
  175. and not if they are functions or lambdas.
  176. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
  177. """
  178. state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
  179. state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
  180. for idx, fn in enumerate(self.lr_lambdas):
  181. if not isinstance(fn, types.FunctionType):
  182. state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
  183. return state_dict
  184. def load_state_dict(self, state_dict):
  185. """Loads the schedulers state.
  186. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
  187. Args:
  188. state_dict (dict): scheduler state. Should be an object returned
  189. from a call to :meth:`state_dict`.
  190. """
  191. lr_lambdas = state_dict.pop('lr_lambdas')
  192. self.__dict__.update(state_dict)
  193. # Restore state_dict keys in order to prevent side effects
  194. # https://github.com/pytorch/pytorch/issues/32756
  195. state_dict['lr_lambdas'] = lr_lambdas
  196. for idx, fn in enumerate(lr_lambdas):
  197. if fn is not None:
  198. self.lr_lambdas[idx].__dict__.update(fn)
  199. def get_lr(self):
  200. if not self._get_lr_called_within_step:
  201. warnings.warn("To get the last learning rate computed by the scheduler, "
  202. "please use `get_last_lr()`.")
  203. return [base_lr * lmbda(self.last_epoch)
  204. for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
  205. class MultiplicativeLR(_LRScheduler):
  206. """Multiply the learning rate of each parameter group by the factor given
  207. in the specified function. When last_epoch=-1, sets initial lr as lr.
  208. Args:
  209. optimizer (Optimizer): Wrapped optimizer.
  210. lr_lambda (function or list): A function which computes a multiplicative
  211. factor given an integer parameter epoch, or a list of such
  212. functions, one for each group in optimizer.param_groups.
  213. last_epoch (int): The index of last epoch. Default: -1.
  214. verbose (bool): If ``True``, prints a message to stdout for
  215. each update. Default: ``False``.
  216. Example:
  217. >>> lmbda = lambda epoch: 0.95
  218. >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
  219. >>> for epoch in range(100):
  220. >>> train(...)
  221. >>> validate(...)
  222. >>> scheduler.step()
  223. """
  224. def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
  225. self.optimizer = optimizer
  226. if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
  227. self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
  228. else:
  229. if len(lr_lambda) != len(optimizer.param_groups):
  230. raise ValueError("Expected {} lr_lambdas, but got {}".format(
  231. len(optimizer.param_groups), len(lr_lambda)))
  232. self.lr_lambdas = list(lr_lambda)
  233. super(MultiplicativeLR, self).__init__(optimizer, last_epoch, verbose)
  234. def state_dict(self):
  235. """Returns the state of the scheduler as a :class:`dict`.
  236. It contains an entry for every variable in self.__dict__ which
  237. is not the optimizer.
  238. The learning rate lambda functions will only be saved if they are callable objects
  239. and not if they are functions or lambdas.
  240. """
  241. state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
  242. state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
  243. for idx, fn in enumerate(self.lr_lambdas):
  244. if not isinstance(fn, types.FunctionType):
  245. state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
  246. return state_dict
  247. def load_state_dict(self, state_dict):
  248. """Loads the schedulers state.
  249. Args:
  250. state_dict (dict): scheduler state. Should be an object returned
  251. from a call to :meth:`state_dict`.
  252. """
  253. lr_lambdas = state_dict.pop('lr_lambdas')
  254. self.__dict__.update(state_dict)
  255. # Restore state_dict keys in order to prevent side effects
  256. # https://github.com/pytorch/pytorch/issues/32756
  257. state_dict['lr_lambdas'] = lr_lambdas
  258. for idx, fn in enumerate(lr_lambdas):
  259. if fn is not None:
  260. self.lr_lambdas[idx].__dict__.update(fn)
  261. def get_lr(self):
  262. if not self._get_lr_called_within_step:
  263. warnings.warn("To get the last learning rate computed by the scheduler, "
  264. "please use `get_last_lr()`.", UserWarning)
  265. if self.last_epoch > 0:
  266. return [group['lr'] * lmbda(self.last_epoch)
  267. for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
  268. else:
  269. return [group['lr'] for group in self.optimizer.param_groups]
  270. class StepLR(_LRScheduler):
  271. """Decays the learning rate of each parameter group by gamma every
  272. step_size epochs. Notice that such decay can happen simultaneously with
  273. other changes to the learning rate from outside this scheduler. When
  274. last_epoch=-1, sets initial lr as lr.
  275. Args:
  276. optimizer (Optimizer): Wrapped optimizer.
  277. step_size (int): Period of learning rate decay.
  278. gamma (float): Multiplicative factor of learning rate decay.
  279. Default: 0.1.
  280. last_epoch (int): The index of last epoch. Default: -1.
  281. verbose (bool): If ``True``, prints a message to stdout for
  282. each update. Default: ``False``.
  283. Example:
  284. >>> # Assuming optimizer uses lr = 0.05 for all groups
  285. >>> # lr = 0.05 if epoch < 30
  286. >>> # lr = 0.005 if 30 <= epoch < 60
  287. >>> # lr = 0.0005 if 60 <= epoch < 90
  288. >>> # ...
  289. >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
  290. >>> for epoch in range(100):
  291. >>> train(...)
  292. >>> validate(...)
  293. >>> scheduler.step()
  294. """
  295. def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
  296. self.step_size = step_size
  297. self.gamma = gamma
  298. super(StepLR, self).__init__(optimizer, last_epoch, verbose)
  299. def get_lr(self):
  300. if not self._get_lr_called_within_step:
  301. warnings.warn("To get the last learning rate computed by the scheduler, "
  302. "please use `get_last_lr()`.", UserWarning)
  303. if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
  304. return [group['lr'] for group in self.optimizer.param_groups]
  305. return [group['lr'] * self.gamma
  306. for group in self.optimizer.param_groups]
  307. def _get_closed_form_lr(self):
  308. return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
  309. for base_lr in self.base_lrs]
  310. class MultiStepLR(_LRScheduler):
  311. """Decays the learning rate of each parameter group by gamma once the
  312. number of epoch reaches one of the milestones. Notice that such decay can
  313. happen simultaneously with other changes to the learning rate from outside
  314. this scheduler. When last_epoch=-1, sets initial lr as lr.
  315. Args:
  316. optimizer (Optimizer): Wrapped optimizer.
  317. milestones (list): List of epoch indices. Must be increasing.
  318. gamma (float): Multiplicative factor of learning rate decay.
  319. Default: 0.1.
  320. last_epoch (int): The index of last epoch. Default: -1.
  321. verbose (bool): If ``True``, prints a message to stdout for
  322. each update. Default: ``False``.
  323. Example:
  324. >>> # Assuming optimizer uses lr = 0.05 for all groups
  325. >>> # lr = 0.05 if epoch < 30
  326. >>> # lr = 0.005 if 30 <= epoch < 80
  327. >>> # lr = 0.0005 if epoch >= 80
  328. >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
  329. >>> for epoch in range(100):
  330. >>> train(...)
  331. >>> validate(...)
  332. >>> scheduler.step()
  333. """
  334. def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
  335. self.milestones = Counter(milestones)
  336. self.gamma = gamma
  337. super(MultiStepLR, self).__init__(optimizer, last_epoch, verbose)
  338. def get_lr(self):
  339. if not self._get_lr_called_within_step:
  340. warnings.warn("To get the last learning rate computed by the scheduler, "
  341. "please use `get_last_lr()`.", UserWarning)
  342. if self.last_epoch not in self.milestones:
  343. return [group['lr'] for group in self.optimizer.param_groups]
  344. return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
  345. for group in self.optimizer.param_groups]
  346. def _get_closed_form_lr(self):
  347. milestones = list(sorted(self.milestones.elements()))
  348. return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
  349. for base_lr in self.base_lrs]
  350. class ConstantLR(_LRScheduler):
  351. """Decays the learning rate of each parameter group by a small constant factor until the
  352. number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can
  353. happen simultaneously with other changes to the learning rate from outside this scheduler.
  354. When last_epoch=-1, sets initial lr as lr.
  355. Args:
  356. optimizer (Optimizer): Wrapped optimizer.
  357. factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
  358. total_iters (int): The number of steps that the scheduler decays the learning rate.
  359. Default: 5.
  360. last_epoch (int): The index of the last epoch. Default: -1.
  361. verbose (bool): If ``True``, prints a message to stdout for
  362. each update. Default: ``False``.
  363. Example:
  364. >>> # Assuming optimizer uses lr = 0.05 for all groups
  365. >>> # lr = 0.025 if epoch == 0
  366. >>> # lr = 0.025 if epoch == 1
  367. >>> # lr = 0.025 if epoch == 2
  368. >>> # lr = 0.025 if epoch == 3
  369. >>> # lr = 0.05 if epoch >= 4
  370. >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
  371. >>> for epoch in range(100):
  372. >>> train(...)
  373. >>> validate(...)
  374. >>> scheduler.step()
  375. """
  376. def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
  377. if factor > 1.0 or factor < 0:
  378. raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')
  379. self.factor = factor
  380. self.total_iters = total_iters
  381. super(ConstantLR, self).__init__(optimizer, last_epoch, verbose)
  382. def get_lr(self):
  383. if not self._get_lr_called_within_step:
  384. warnings.warn("To get the last learning rate computed by the scheduler, "
  385. "please use `get_last_lr()`.", UserWarning)
  386. if self.last_epoch == 0:
  387. return [group['lr'] * self.factor for group in self.optimizer.param_groups]
  388. if (self.last_epoch > self.total_iters or
  389. (self.last_epoch != self.total_iters)):
  390. return [group['lr'] for group in self.optimizer.param_groups]
  391. if (self.last_epoch == self.total_iters):
  392. return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]
  393. def _get_closed_form_lr(self):
  394. return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
  395. for base_lr in self.base_lrs]
  396. class LinearLR(_LRScheduler):
  397. """Decays the learning rate of each parameter group by linearly changing small
  398. multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
  399. Notice that such decay can happen simultaneously with other changes to the learning rate
  400. from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
  401. Args:
  402. optimizer (Optimizer): Wrapped optimizer.
  403. start_factor (float): The number we multiply learning rate in the first epoch.
  404. The multiplication factor changes towards end_factor in the following epochs.
  405. Default: 1./3.
  406. end_factor (float): The number we multiply learning rate at the end of linear changing
  407. process. Default: 1.0.
  408. total_iters (int): The number of iterations that multiplicative factor reaches to 1.
  409. Default: 5.
  410. last_epoch (int): The index of the last epoch. Default: -1.
  411. verbose (bool): If ``True``, prints a message to stdout for
  412. each update. Default: ``False``.
  413. Example:
  414. >>> # Assuming optimizer uses lr = 0.05 for all groups
  415. >>> # lr = 0.025 if epoch == 0
  416. >>> # lr = 0.03125 if epoch == 1
  417. >>> # lr = 0.0375 if epoch == 2
  418. >>> # lr = 0.04375 if epoch == 3
  419. >>> # lr = 0.05 if epoch >= 4
  420. >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
  421. >>> for epoch in range(100):
  422. >>> train(...)
  423. >>> validate(...)
  424. >>> scheduler.step()
  425. """
  426. def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
  427. verbose=False):
  428. if start_factor > 1.0 or start_factor < 0:
  429. raise ValueError('Starting multiplicative factor expected to be between 0 and 1.')
  430. if end_factor > 1.0 or end_factor < 0:
  431. raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
  432. self.start_factor = start_factor
  433. self.end_factor = end_factor
  434. self.total_iters = total_iters
  435. super(LinearLR, self).__init__(optimizer, last_epoch, verbose)
  436. def get_lr(self):
  437. if not self._get_lr_called_within_step:
  438. warnings.warn("To get the last learning rate computed by the scheduler, "
  439. "please use `get_last_lr()`.", UserWarning)
  440. if self.last_epoch == 0:
  441. return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]
  442. if (self.last_epoch > self.total_iters):
  443. return [group['lr'] for group in self.optimizer.param_groups]
  444. return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
  445. (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
  446. for group in self.optimizer.param_groups]
  447. def _get_closed_form_lr(self):
  448. return [base_lr * (self.start_factor +
  449. (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
  450. for base_lr in self.base_lrs]
  451. class ExponentialLR(_LRScheduler):
  452. """Decays the learning rate of each parameter group by gamma every epoch.
  453. When last_epoch=-1, sets initial lr as lr.
  454. Args:
  455. optimizer (Optimizer): Wrapped optimizer.
  456. gamma (float): Multiplicative factor of learning rate decay.
  457. last_epoch (int): The index of last epoch. Default: -1.
  458. verbose (bool): If ``True``, prints a message to stdout for
  459. each update. Default: ``False``.
  460. """
  461. def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
  462. self.gamma = gamma
  463. super(ExponentialLR, self).__init__(optimizer, last_epoch, verbose)
  464. def get_lr(self):
  465. if not self._get_lr_called_within_step:
  466. warnings.warn("To get the last learning rate computed by the scheduler, "
  467. "please use `get_last_lr()`.", UserWarning)
  468. if self.last_epoch == 0:
  469. return [group['lr'] for group in self.optimizer.param_groups]
  470. return [group['lr'] * self.gamma
  471. for group in self.optimizer.param_groups]
  472. def _get_closed_form_lr(self):
  473. return [base_lr * self.gamma ** self.last_epoch
  474. for base_lr in self.base_lrs]
  475. class SequentialLR(_LRScheduler):
  476. """Receives the list of schedulers that is expected to be called sequentially during
  477. optimization process and milestone points that provides exact intervals to reflect
  478. which scheduler is supposed to be called at a given epoch.
  479. Args:
  480. optimizer (Optimizer): Wrapped optimizer.
  481. schedulers (list): List of chained schedulers.
  482. milestones (list): List of integers that reflects milestone points.
  483. last_epoch (int): The index of last epoch. Default: -1.
  484. verbose (bool): Does nothing.
  485. Example:
  486. >>> # Assuming optimizer uses lr = 1. for all groups
  487. >>> # lr = 0.1 if epoch == 0
  488. >>> # lr = 0.1 if epoch == 1
  489. >>> # lr = 0.9 if epoch == 2
  490. >>> # lr = 0.81 if epoch == 3
  491. >>> # lr = 0.729 if epoch == 4
  492. >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
  493. >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
  494. >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
  495. >>> for epoch in range(100):
  496. >>> train(...)
  497. >>> validate(...)
  498. >>> scheduler.step()
  499. """
  500. def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
  501. for scheduler_idx in range(len(schedulers)):
  502. if schedulers[scheduler_idx].optimizer != optimizer:
  503. raise ValueError(
  504. "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
  505. f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
  506. )
  507. if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
  508. raise ValueError(
  509. "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
  510. f"got schedulers at index {0} and {scheduler_idx} to be different."
  511. )
  512. if (len(milestones) != len(schedulers) - 1):
  513. raise ValueError(
  514. "Sequential Schedulers expects number of schedulers provided to be one more "
  515. "than the number of milestone points, but got number of schedulers {} and the "
  516. "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
  517. )
  518. self._schedulers = schedulers
  519. self._milestones = milestones
  520. self.last_epoch = last_epoch + 1
  521. self.optimizer = optimizer
  522. self._last_lr = schedulers[0].get_last_lr()
  523. def step(self):
  524. self.last_epoch += 1
  525. idx = bisect_right(self._milestones, self.last_epoch)
  526. if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
  527. self._schedulers[idx].step(0)
  528. else:
  529. self._schedulers[idx].step()
  530. self._last_lr = self._schedulers[idx].get_last_lr()
  531. def state_dict(self):
  532. """Returns the state of the scheduler as a :class:`dict`.
  533. It contains an entry for every variable in self.__dict__ which
  534. is not the optimizer.
  535. The wrapped scheduler states will also be saved.
  536. """
  537. state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
  538. state_dict['_schedulers'] = [None] * len(self._schedulers)
  539. for idx, s in enumerate(self._schedulers):
  540. state_dict['_schedulers'][idx] = s.state_dict()
  541. return state_dict
  def load_state_dict(self, state_dict):
      """Load the scheduler state.

      Args:
          state_dict (dict): scheduler state. Should be an object returned
              from a call to :meth:`state_dict`.

      The ``'_schedulers'`` entry is removed before the top-level attribute
      update, so ``self._schedulers`` keeps its live scheduler objects. Each
      saved sub-state is then forwarded, by position, to the matching wrapped
      scheduler. The key is put back so the caller's dict keeps the same keys
      (https://github.com/pytorch/pytorch/issues/32756).
      """
      nested_states = state_dict.pop('_schedulers')
      self.__dict__.update(state_dict)
      # Restore the popped key so the caller does not observe a missing entry.
      state_dict['_schedulers'] = nested_states
      for position, child_state in enumerate(nested_states):
          self._schedulers[position].load_state_dict(child_state)
  class CosineAnnealingLR(_LRScheduler):
      r"""Set the learning rate of each parameter group using a cosine annealing
      schedule, where :math:`\eta_{max}` is set to the initial lr and
      :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

      .. math::
          \begin{aligned}
              \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
              + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
              & T_{cur} \neq (2k+1)T_{max}; \\
              \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
              \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
              & T_{cur} = (2k+1)T_{max}.
          \end{aligned}

      When last_epoch=-1, sets initial lr as lr. Because the schedule is defined
      recursively (each new rate is derived from the optimizer's current rate),
      the learning rate can simultaneously be modified outside this scheduler
      by other operators. If the learning rate is set solely by this scheduler,
      the learning rate at each step becomes:

      .. math::
          \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
          \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

      It has been proposed in
      `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
      implements the cosine annealing part of SGDR, and not the restarts.

      Args:
          optimizer (Optimizer): Wrapped optimizer.
          T_max (int): Maximum number of iterations.
          eta_min (float): Minimum learning rate. Default: 0.
          last_epoch (int): The index of last epoch. Default: -1.
          verbose (bool): If ``True``, prints a message to stdout for
              each update. Default: ``False``.

      .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
          https://arxiv.org/abs/1608.03983
      """

      def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
          # Stored before delegating because get_lr() reads both; the base
          # constructor presumably performs an initial step (TODO confirm).
          self.T_max = T_max
          self.eta_min = eta_min
          super(CosineAnnealingLR, self).__init__(optimizer, last_epoch, verbose)

      def get_lr(self):
          """Compute the next learning rate of every param group.

          * epoch 0: keep each group's current ``lr``;
          * first counted step (``_step_count == 1``) at a positive epoch: use
            the closed form from ``base_lrs``;
          * previous epoch at an odd multiple of ``T_max`` (where the ratio form
            would divide by zero): apply the additive update;
          * otherwise: rescale each group's distance to ``eta_min`` by the ratio
            of consecutive cosine factors.
          """
          if not self._get_lr_called_within_step:
              warnings.warn("To get the last learning rate computed by the scheduler, "
                            "please use `get_last_lr()`.", UserWarning)

          epoch = self.last_epoch
          groups = self.optimizer.param_groups

          if epoch == 0:
              return [group['lr'] for group in groups]

          if self._step_count == 1 and epoch > 0:
              cosine_term = 1 + math.cos((epoch) * math.pi / self.T_max)
              return [self.eta_min + (base_lr - self.eta_min) * cosine_term / 2
                      for base_lr, _ in zip(self.base_lrs, groups)]

          if (epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
              # Here cos(pi * (epoch - 1) / T_max) == -1, so the ratio below
              # would have a zero denominator.
              bump = 1 - math.cos(math.pi / self.T_max)
              return [group['lr'] + (base_lr - self.eta_min) * bump / 2
                      for base_lr, group in zip(self.base_lrs, groups)]

          ratio = ((1 + math.cos(math.pi * epoch / self.T_max)) /
                   (1 + math.cos(math.pi * (epoch - 1) / self.T_max)))
          return [ratio * (group['lr'] - self.eta_min) + self.eta_min
                  for group in groups]

      def _get_closed_form_lr(self):
          """Return the non-recursive cosine-annealed rate for every base lr."""
          cosine_term = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
          return [self.eta_min + (base_lr - self.eta_min) * cosine_term / 2
                  for base_lr in self.base_lrs]
  class ChainedScheduler(_LRScheduler):
      """Chains a list of learning rate schedulers.

      Takes a sequence of chainable learning rate schedulers and, with a single
      call to :meth:`step`, calls ``step()`` on each of them in order.

      Args:
          schedulers (list): List of chained schedulers. Any iterable is
              accepted and copied into a new list. It must be non-empty and
              all schedulers must wrap the same optimizer.

      Raises:
          ValueError: If ``schedulers`` is empty, or if two schedulers wrap
              different optimizers.

      Example:
          >>> # Assuming optimizer uses lr = 1. for all groups
          >>> # lr = 0.09     if epoch == 0
          >>> # lr = 0.081    if epoch == 1
          >>> # lr = 0.729    if epoch == 2
          >>> # lr = 0.6561   if epoch == 3
          >>> # lr = 0.59049  if epoch >= 4
          >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
          >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
          >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
          >>> for epoch in range(100):
          >>>     train(...)
          >>>     validate(...)
          >>>     scheduler.step()
      """

      def __init__(self, schedulers):
          # Materialize once: allows generators and avoids aliasing the caller's list.
          schedulers = list(schedulers)
          if not schedulers:
              # Without this guard, schedulers[0] below fails with an opaque IndexError.
              raise ValueError(
                  "ChainedScheduler expects at least one scheduler to be chained, "
                  "but got no scheduler."
              )
          for scheduler_idx in range(1, len(schedulers)):
              if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
                  raise ValueError(
                      "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
                      "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                  )
          self._schedulers = schedulers
          self.optimizer = schedulers[0].optimizer
          # Cache the rates currently held by the (shared) optimizer.
          self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]

      def step(self):
          """Step every chained scheduler in order, then cache the resulting rates."""
          for scheduler in self._schedulers:
              scheduler.step()
          self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]

      def state_dict(self):
          """Returns the state of the scheduler as a :class:`dict`.

          It contains an entry for every variable in self.__dict__ which
          is not the optimizer. The wrapped scheduler states will also be saved,
          as a list under the ``'_schedulers'`` key.
          """
          state_dict = {key: value for key, value in self.__dict__.items()
                        if key not in ('optimizer', '_schedulers')}
          state_dict['_schedulers'] = [s.state_dict() for s in self._schedulers]
          return state_dict

      def load_state_dict(self, state_dict):
          """Loads the schedulers state.

          Args:
              state_dict (dict): scheduler state. Should be an object returned
                  from a call to :meth:`state_dict`.
          """
          _schedulers = state_dict.pop('_schedulers')
          self.__dict__.update(state_dict)
          # Restore state_dict keys in order to prevent side effects
          # https://github.com/pytorch/pytorch/issues/32756
          state_dict['_schedulers'] = _schedulers

          for idx, s in enumerate(_schedulers):
              self._schedulers[idx].load_state_dict(s)
  class ReduceLROnPlateau(object):
      """Reduce learning rate when a metric has stopped improving.

      Models often benefit from reducing the learning rate by a factor
      of 2-10 once learning stagnates. This scheduler reads a metric
      quantity and, if no improvement is seen for a ``patience`` number
      of epochs, the learning rate is reduced.

      Args:
          optimizer (Optimizer): Wrapped optimizer.
          mode (str): One of `min`, `max`. In `min` mode, lr will
              be reduced when the quantity monitored has stopped
              decreasing; in `max` mode it will be reduced when the
              quantity monitored has stopped increasing. Default: 'min'.
          factor (float): Factor by which the learning rate will be
              reduced. new_lr = lr * factor. Must be < 1.0. Default: 0.1.
          patience (int): Number of epochs with no improvement after
              which learning rate will be reduced. For example, if
              `patience = 2`, then we will ignore the first 2 epochs
              with no improvement, and will only decrease the LR after the
              3rd epoch if the loss still hasn't improved then.
              Default: 10.
          threshold (float): Threshold for measuring the new optimum,
              to only focus on significant changes. Default: 1e-4.
          threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
              dynamic_threshold = best * ( 1 + threshold ) in `max`
              mode or best * ( 1 - threshold ) in `min` mode.
              In `abs` mode, dynamic_threshold = best + threshold in
              `max` mode or best - threshold in `min` mode. Default: 'rel'.
          cooldown (int): Number of epochs to wait before resuming
              normal operation after lr has been reduced. Default: 0.
          min_lr (float or list): A scalar, or a list/tuple of scalars (one
              per param group). A lower bound on the learning rate of all
              param groups or each group respectively. Default: 0.
          eps (float): Minimal decay applied to lr. If the difference
              between new and old lr is smaller than eps, the update is
              ignored. Default: 1e-8.
          verbose (bool): If ``True``, prints a message to stdout for
              each update. Default: ``False``.

      Raises:
          ValueError: If ``factor >= 1.0``, if ``mode`` or ``threshold_mode``
              is not one of the accepted values, or if a per-group ``min_lr``
              has the wrong length.
          TypeError: If ``optimizer`` is not an :class:`Optimizer`.

      Example:
          >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
          >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
          >>> for epoch in range(10):
          >>>     train(...)
          >>>     val_loss = validate(...)
          >>>     # Note that step should be called after validate()
          >>>     scheduler.step(val_loss)
      """

      def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                   threshold=1e-4, threshold_mode='rel', cooldown=0,
                   min_lr=0, eps=1e-8, verbose=False):

          if factor >= 1.0:
              raise ValueError('Factor should be < 1.0.')
          self.factor = factor

          # Attach optimizer
          if not isinstance(optimizer, Optimizer):
              raise TypeError('{} is not an Optimizer'.format(
                  type(optimizer).__name__))
          self.optimizer = optimizer

          # One floor shared by all groups, or one floor per group.
          if isinstance(min_lr, (list, tuple)):
              if len(min_lr) != len(optimizer.param_groups):
                  raise ValueError("expected {} min_lrs, got {}".format(
                      len(optimizer.param_groups), len(min_lr)))
              self.min_lrs = list(min_lr)
          else:
              self.min_lrs = [min_lr] * len(optimizer.param_groups)

          self.patience = patience
          self.verbose = verbose
          self.cooldown = cooldown
          self.cooldown_counter = 0
          self.mode = mode
          self.threshold = threshold
          self.threshold_mode = threshold_mode
          self.best = None
          self.num_bad_epochs = None
          self.mode_worse = None  # the worse value for the chosen mode
          self.eps = eps
          self.last_epoch = 0
          self._init_is_better(mode=mode, threshold=threshold,
                               threshold_mode=threshold_mode)
          self._reset()

      def _reset(self):
          """Resets num_bad_epochs counter and cooldown counter."""
          self.best = self.mode_worse
          self.cooldown_counter = 0
          self.num_bad_epochs = 0

      def step(self, metrics, epoch=None):
          """Record the latest metric value and reduce the lr if it has plateaued.

          Args:
              metrics: Monitored value. It is converted with ``float()``, so a
                  Python number or a zero-dim Tensor is accepted.
              epoch (optional): Explicit epoch index. Passing it emits
                  ``EPOCH_DEPRECATION_WARNING`` as a ``UserWarning``. When
                  omitted, ``last_epoch + 1`` is used.
          """
          # convert `metrics` to float, in case it's a zero-dim Tensor
          current = float(metrics)
          if epoch is None:
              epoch = self.last_epoch + 1
          else:
              warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
          self.last_epoch = epoch

          if self.is_better(current, self.best):
              self.best = current
              self.num_bad_epochs = 0
          else:
              self.num_bad_epochs += 1

          if self.in_cooldown:
              self.cooldown_counter -= 1
              self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

          if self.num_bad_epochs > self.patience:
              self._reduce_lr(epoch)
              self.cooldown_counter = self.cooldown
              self.num_bad_epochs = 0

          self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

      def _reduce_lr(self, epoch):
          """Multiply each group's lr by ``factor``, clamped at its ``min_lr``.

          A group is only updated when the decrease exceeds ``eps``.
          """
          for i, param_group in enumerate(self.optimizer.param_groups):
              old_lr = float(param_group['lr'])
              new_lr = max(old_lr * self.factor, self.min_lrs[i])
              if old_lr - new_lr > self.eps:
                  param_group['lr'] = new_lr
                  if self.verbose:
                      epoch_str = ("%.2f" if isinstance(epoch, float) else
                                   "%.5d") % epoch
                      print('Epoch {}: reducing learning rate'
                            ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))

      @property
      def in_cooldown(self):
          """True while the post-reduction cooldown counter is still positive."""
          return self.cooldown_counter > 0

      def is_better(self, a, best):
          """Return True if ``a`` improves on ``best`` by more than the threshold."""
          if self.mode == 'min' and self.threshold_mode == 'rel':
              rel_epsilon = 1. - self.threshold
              return a < best * rel_epsilon

          elif self.mode == 'min' and self.threshold_mode == 'abs':
              return a < best - self.threshold

          elif self.mode == 'max' and self.threshold_mode == 'rel':
              rel_epsilon = self.threshold + 1.
              return a > best * rel_epsilon

          else:  # mode == 'max' and threshold_mode == 'abs':
              return a > best + self.threshold

      def _init_is_better(self, mode, threshold, threshold_mode):
          """Validate the comparison settings and set the initial worst value."""
          # Tuple membership (equality only, no hashing) plus format()-built
          # messages ensure any invalid value -- including None or an unhashable
          # object -- raises ValueError rather than a TypeError from a set lookup
          # or from str concatenation.
          if mode not in ('min', 'max'):
              raise ValueError('mode {} is unknown!'.format(mode))
          if threshold_mode not in ('rel', 'abs'):
              raise ValueError('threshold mode {} is unknown!'.format(threshold_mode))

          if mode == 'min':
              self.mode_worse = inf
          else:  # mode == 'max':
              self.mode_worse = -inf

          self.mode = mode
          self.threshold = threshold
          self.threshold_mode = threshold_mode

      def state_dict(self):
          """Return every attribute except the wrapped optimizer."""
          return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

      def load_state_dict(self, state_dict):
          """Restore attributes from ``state_dict`` and re-validate the comparison settings."""
          self.__dict__.update(state_dict)
          self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
  class CyclicLR(_LRScheduler):
      r"""Sets the learning rate of each parameter group according to the
      cyclical learning rate policy (CLR).

      The policy cycles the learning rate between two boundaries with a
      constant frequency, as detailed in the paper
      `Cyclical Learning Rates for Training Neural Networks`_. The distance
      between the two boundaries can be scaled on a per-iteration or
      per-cycle basis.

      The cyclical policy changes the learning rate after every batch, so
      `step` should be called after a batch has been used for training.

      Three built-in policies are available, as put forth in the paper:

      * "triangular": a basic triangular cycle without amplitude scaling.
      * "triangular2": a basic triangular cycle that scales the initial
        amplitude by half each cycle.
      * "exp_range": a cycle that scales the initial amplitude by
        :math:`\text{gamma}^{\text{iterations}}`, where the iteration count
        is the scheduler's batch counter (``last_epoch``).

      This implementation was adapted from the github repo: `bckenstler/CLR`_

      Args:
          optimizer (Optimizer): Wrapped optimizer.
          base_lr (float or list): Initial learning rate which is the
              lower boundary in the cycle for each parameter group.
          max_lr (float or list): Upper learning rate boundaries in the cycle
              for each parameter group. Functionally, it defines the cycle
              amplitude (max_lr - base_lr). The lr at any cycle is the sum of
              base_lr and some scaling of the amplitude; therefore max_lr may
              not actually be reached depending on the scaling function.
          step_size_up (int): Number of training iterations in the
              increasing half of a cycle. Default: 2000
          step_size_down (int): Number of training iterations in the
              decreasing half of a cycle. If step_size_down is None,
              it is set to step_size_up. Default: None
          mode (str): One of {triangular, triangular2, exp_range}.
              Values correspond to the policies detailed above.
              If scale_fn is not None, this argument is ignored.
              Default: 'triangular'
          gamma (float): Constant in the 'exp_range' scaling function:
              gamma**(iterations). Default: 1.0
          scale_fn (function): Custom scaling policy defined by a single
              argument lambda function, where 0 <= scale_fn(x) <= 1 for all
              x >= 0. If specified, then 'mode' is ignored. Default: None
          scale_mode (str): {'cycle', 'iterations'}. Defines whether scale_fn
              is evaluated on the cycle number or on the iteration count
              (``last_epoch``). Only used together with a custom scale_fn.
              Default: 'cycle'
          cycle_momentum (bool): If ``True``, momentum is cycled inversely
              to learning rate between 'base_momentum' and 'max_momentum'.
              Default: True
          base_momentum (float or list): Lower momentum boundaries in the cycle
              for each parameter group. Note that momentum is cycled inversely
              to learning rate; at the peak of a cycle, momentum is
              'base_momentum' and learning rate is 'max_lr'.
              Default: 0.8
          max_momentum (float or list): Upper momentum boundaries in the cycle
              for each parameter group. Functionally, it defines the cycle
              amplitude (max_momentum - base_momentum). The momentum at any
              cycle is the difference of max_momentum and some scaling of the
              amplitude; therefore base_momentum may not actually be reached
              depending on the scaling function. Note that momentum is cycled
              inversely to learning rate; at the start of a cycle, momentum is
              'max_momentum' and learning rate is 'base_lr'.
              Default: 0.9
          last_epoch (int): The index of the last batch. This parameter is used
              when resuming a training job. Since `step()` should be invoked
              after each batch instead of after each epoch, this number
              represents the total number of *batches* computed, not the total
              number of epochs computed. When last_epoch=-1, the schedule is
              started from the beginning. Default: -1
          verbose (bool): If ``True``, prints a message to stdout for
              each update. Default: ``False``.

      Raises:
          TypeError: If ``optimizer`` is not an :class:`Optimizer`.
          ValueError: If a per-group argument has the wrong length, if ``mode``
              is invalid while ``scale_fn`` is None, or if ``cycle_momentum``
              is enabled for an optimizer without a ``momentum`` default.

      Example:
          >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
          >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
          >>> data_loader = torch.utils.data.DataLoader(...)
          >>> for epoch in range(10):
          >>>     for batch in data_loader:
          >>>         train_batch(...)
          >>>         scheduler.step()

      .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
      .. _bckenstler/CLR: https://github.com/bckenstler/CLR
      """

      def __init__(self,
                   optimizer,
                   base_lr,
                   max_lr,
                   step_size_up=2000,
                   step_size_down=None,
                   mode='triangular',
                   gamma=1.,
                   scale_fn=None,
                   scale_mode='cycle',
                   cycle_momentum=True,
                   base_momentum=0.8,
                   max_momentum=0.9,
                   last_epoch=-1,
                   verbose=False):

          if not isinstance(optimizer, Optimizer):
              raise TypeError('{} is not an Optimizer'.format(
                  type(optimizer).__name__))
          self.optimizer = optimizer

          fresh_start = last_epoch == -1

          base_lrs = self._format_param('base_lr', optimizer, base_lr)
          if fresh_start:
              # A schedule started from scratch puts every group at the bottom of the cycle.
              for group, lr in zip(optimizer.param_groups, base_lrs):
                  group['lr'] = lr

          self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

          up = float(step_size_up)
          down = up if step_size_down is None else float(step_size_down)
          self.total_size = up + down
          self.step_ratio = up / self.total_size

          if mode not in ('triangular', 'triangular2', 'exp_range') \
                  and scale_fn is None:
              raise ValueError('mode is invalid and scale_fn is None')

          self.mode = mode
          self.gamma = gamma

          if scale_fn is None:
              # Built-in policy -> (scaling function, what it is evaluated on).
              policy_table = {
                  'triangular': (self._triangular_scale_fn, 'cycle'),
                  'triangular2': (self._triangular2_scale_fn, 'cycle'),
                  'exp_range': (self._exp_range_scale_fn, 'iterations'),
              }
              self.scale_fn, self.scale_mode = policy_table[self.mode]
          else:
              self.scale_fn = scale_fn
              self.scale_mode = scale_mode

          self.cycle_momentum = cycle_momentum
          if cycle_momentum:
              if 'momentum' not in optimizer.defaults:
                  raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')

              base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
              if fresh_start:
                  for group, momentum in zip(optimizer.param_groups, base_momentums):
                      group['momentum'] = momentum
              self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
              self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

          super(CyclicLR, self).__init__(optimizer, last_epoch, verbose)
          # Re-assigned after the base constructor, which presumably sets its
          # own base_lrs (TODO confirm against _LRScheduler.__init__).
          self.base_lrs = base_lrs

      def _format_param(self, name, optimizer, param):
          """Return correctly formatted lr/momentum for each param group."""
          group_count = len(optimizer.param_groups)
          if not isinstance(param, (list, tuple)):
              return [param] * group_count
          if len(param) != group_count:
              raise ValueError("expected {} values for {}, got {}".format(
                  group_count, name, len(param)))
          return param

      def _triangular_scale_fn(self, x):
          return 1.

      def _triangular2_scale_fn(self, x):
          return 1 / (2. ** (x - 1))

      def _exp_range_scale_fn(self, x):
          return self.gamma ** x

      def get_lr(self):
          """Calculates the learning rate at batch index. This function treats
          `self.last_epoch` as the last batch index.

          If `self.cycle_momentum` is ``True``, this function has a side effect of
          updating the optimizer's momentum.
          """
          if not self._get_lr_called_within_step:
              warnings.warn("To get the last learning rate computed by the scheduler, "
                            "please use `get_last_lr()`.", UserWarning)

          # 1-based cycle number and the position within it, in [0, 1).
          cycle = math.floor(1 + self.last_epoch / self.total_size)
          x = 1. + self.last_epoch / self.total_size - cycle
          rising = x <= self.step_ratio
          scale_factor = x / self.step_ratio if rising else (x - 1) / (self.step_ratio - 1)

          scale_arg = cycle if self.scale_mode == 'cycle' else self.last_epoch

          lrs = [base_lr + (max_lr - base_lr) * scale_factor * self.scale_fn(scale_arg)
                 for base_lr, max_lr in zip(self.base_lrs, self.max_lrs)]

          if self.cycle_momentum:
              # Momentum moves opposite to the lr: it falls from max_momentum as the lr rises.
              momentums = [max_momentum - (max_momentum - base_momentum) * scale_factor * self.scale_fn(scale_arg)
                           for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums)]
              for group, momentum in zip(self.optimizer.param_groups, momentums):
                  group['momentum'] = momentum

          return lrs
  class CosineAnnealingWarmRestarts(_LRScheduler):
      r"""Set the learning rate of each parameter group using a cosine annealing
      schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
      is the number of epochs since the last restart and :math:`T_{i}` is the number
      of epochs between two warm restarts in SGDR:

      .. math::
          \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
          \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

      When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
      When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

      It has been proposed in
      `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

      Args:
          optimizer (Optimizer): Wrapped optimizer.
          T_0 (int): Number of iterations for the first restart.
          T_mult (int, optional): A factor by which :math:`T_{i}` increases
              after a restart. Default: 1.
          eta_min (float, optional): Minimum learning rate. Default: 0.
          last_epoch (int, optional): The index of last epoch. Default: -1.
          verbose (bool): If ``True``, prints a message to stdout for
              each update. Default: ``False``.

      Raises:
          ValueError: If ``T_0`` is not a positive ``int`` or ``T_mult`` is not
              an ``int`` >= 1.

      .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
          https://arxiv.org/abs/1608.03983
      """

      def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
          # Type-check before comparing, so non-numeric input (e.g. None) raises
          # the documented ValueError instead of a TypeError from the comparison.
          if not isinstance(T_0, int) or T_0 <= 0:
              raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
          if not isinstance(T_mult, int) or T_mult < 1:
              raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
          self.T_0 = T_0
          self.T_i = T_0
          self.T_mult = T_mult
          self.eta_min = eta_min
          self.T_cur = last_epoch
          super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)

      def get_lr(self):
          """Cosine-annealed rate for the current position ``T_cur`` within cycle ``T_i``."""
          if not self._get_lr_called_within_step:
              warnings.warn("To get the last learning rate computed by the scheduler, "
                            "please use `get_last_lr()`.", UserWarning)

          return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                  for base_lr in self.base_lrs]

      def _restart_cycle_index(self, epoch):
          """Return the index ``n`` of the restart cycle that contains ``epoch``.

          Cycle ``n`` starts at ``T_0 * (T_mult**n - 1) / (T_mult - 1)``. Only
          meaningful for ``T_mult > 1`` and ``epoch >= 0``.
          """
          n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))

          # math.log(x, base) is computed as log(x) / log(base) and can land just
          # below an exact integer (e.g. math.log(1000, 10) == 2.9999999999999996).
          # Truncation would then put a restart epoch at the very end of the
          # previous cycle (T_cur == T_i, i.e. eta_min instead of the restart lr).
          # Correct n against the exact integer cycle boundaries.
          def cycle_start(k):
              # The geometric sum is an exact integer, so floor division is exact.
              return self.T_0 * (self.T_mult ** k - 1) // (self.T_mult - 1)

          while n > 0 and cycle_start(n) > epoch:
              n -= 1
          while cycle_start(n + 1) <= epoch:
              n += 1
          return n

      def step(self, epoch=None):
          """Step could be called after every batch update

          Example:
              >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
              >>> iters = len(dataloader)
              >>> for epoch in range(20):
              >>>     for i, sample in enumerate(dataloader):
              >>>         inputs, labels = sample['inputs'], sample['labels']
              >>>         optimizer.zero_grad()
              >>>         outputs = net(inputs)
              >>>         loss = criterion(outputs, labels)
              >>>         loss.backward()
              >>>         optimizer.step()
              >>>         scheduler.step(epoch + i / iters)

          This function can be called in an interleaved way.

          Example:
              >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
              >>> for epoch in range(20):
              >>>     scheduler.step()
              >>> scheduler.step(26)
              >>> scheduler.step() # scheduler.step(27), instead of scheduler.step(20)

          Args:
              epoch (int or float, optional): Explicit (possibly fractional)
                  epoch. When omitted, the schedule advances by one epoch.

          Raises:
              ValueError: If an explicit ``epoch`` is negative.
          """

          if epoch is None and self.last_epoch < 0:
              epoch = 0

          if epoch is None:
              # Implicit step: advance one epoch and roll over into the next
              # cycle (T_mult times longer) once the current one is exhausted.
              epoch = self.last_epoch + 1
              self.T_cur = self.T_cur + 1
              if self.T_cur >= self.T_i:
                  self.T_cur = self.T_cur - self.T_i
                  self.T_i = self.T_i * self.T_mult
          else:
              if epoch < 0:
                  raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
              if epoch >= self.T_0:
                  if self.T_mult == 1:
                      self.T_cur = epoch % self.T_0
                  else:
                      n = self._restart_cycle_index(epoch)
                      self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                      self.T_i = self.T_0 * self.T_mult ** (n)
              else:
                  self.T_i = self.T_0
                  self.T_cur = epoch
          self.last_epoch = math.floor(epoch)

          class _enable_get_lr_call:
              # Marks get_lr() as being called from within step(), which
              # suppresses its "use get_last_lr()" warning.

              def __init__(self, o):
                  self.o = o

              def __enter__(self):
                  self.o._get_lr_called_within_step = True
                  return self

              def __exit__(self, type, value, traceback):
                  self.o._get_lr_called_within_step = False
                  return self

          with _enable_get_lr_call(self):
              for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                  param_group, lr = data
                  param_group['lr'] = lr
                  self.print_lr(self.verbose, i, lr, epoch)

          self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  1115. class OneCycleLR(_LRScheduler):
  1116. r"""Sets the learning rate of each parameter group according to the
  1117. 1cycle learning rate policy. The 1cycle policy anneals the learning
  1118. rate from an initial learning rate to some maximum learning rate and then
  1119. from that maximum learning rate to some minimum learning rate much lower
  1120. than the initial learning rate.
  1121. This policy was initially described in the paper `Super-Convergence:
  1122. Very Fast Training of Neural Networks Using Large Learning Rates`_.
  1123. The 1cycle learning rate policy changes the learning rate after every batch.
  1124. `step` should be called after a batch has been used for training.
  1125. This scheduler is not chainable.
  1126. Note also that the total number of steps in the cycle can be determined in one
  1127. of two ways (listed in order of precedence):
  1128. #. A value for total_steps is explicitly provided.
  1129. #. A number of epochs (epochs) and a number of steps per epoch
  1130. (steps_per_epoch) are provided.
  1131. In this case, the number of total steps is inferred by
  1132. total_steps = epochs * steps_per_epoch
  1133. You must either provide a value for total_steps or provide a value for both
  1134. epochs and steps_per_epoch.
  1135. The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
  1136. claims that "unpublished work has shown even better results by using only two phases". To
  1137. mimic the behaviour of the original paper instead, set ``three_phase=True``.
  1138. Args:
  1139. optimizer (Optimizer): Wrapped optimizer.
  1140. max_lr (float or list): Upper learning rate boundaries in the cycle
  1141. for each parameter group.
  1142. total_steps (int): The total number of steps in the cycle. Note that
  1143. if a value is not provided here, then it must be inferred by providing
  1144. a value for epochs and steps_per_epoch.
  1145. Default: None
  1146. epochs (int): The number of epochs to train for. This is used along
  1147. with steps_per_epoch in order to infer the total number of steps in the cycle
  1148. if a value for total_steps is not provided.
  1149. Default: None
  1150. steps_per_epoch (int): The number of steps per epoch to train for. This is
  1151. used along with epochs in order to infer the total number of steps in the
  1152. cycle if a value for total_steps is not provided.
  1153. Default: None
  1154. pct_start (float): The percentage of the cycle (in number of steps) spent
  1155. increasing the learning rate.
  1156. Default: 0.3
  1157. anneal_strategy (str): {'cos', 'linear'}
  1158. Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
  1159. linear annealing.
  1160. Default: 'cos'
  1161. cycle_momentum (bool): If ``True``, momentum is cycled inversely
  1162. to learning rate between 'base_momentum' and 'max_momentum'.
  1163. Default: True
  1164. base_momentum (float or list): Lower momentum boundaries in the cycle
  1165. for each parameter group. Note that momentum is cycled inversely
  1166. to learning rate; at the peak of a cycle, momentum is
  1167. 'base_momentum' and learning rate is 'max_lr'.
  1168. Default: 0.85
  1169. max_momentum (float or list): Upper momentum boundaries in the cycle
  1170. for each parameter group. Functionally,
  1171. it defines the cycle amplitude (max_momentum - base_momentum).
  1172. Note that momentum is cycled inversely
  1173. to learning rate; at the start of a cycle, momentum is 'max_momentum'
  1174. and learning rate is 'base_lr'
  1175. Default: 0.95
  1176. div_factor (float): Determines the initial learning rate via
  1177. initial_lr = max_lr/div_factor
  1178. Default: 25
  1179. final_div_factor (float): Determines the minimum learning rate via
  1180. min_lr = initial_lr/final_div_factor
  1181. Default: 1e4
  1182. three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
  1183. learning rate according to 'final_div_factor' instead of modifying the second
  1184. phase (the first two phases will be symmetrical about the step indicated by
  1185. 'pct_start').
  1186. last_epoch (int): The index of the last batch. This parameter is used when
  1187. resuming a training job. Since `step()` should be invoked after each
  1188. batch instead of after each epoch, this number represents the total
  1189. number of *batches* computed, not the total number of epochs computed.
  1190. When last_epoch=-1, the schedule is started from the beginning.
  1191. Default: -1
  1192. verbose (bool): If ``True``, prints a message to stdout for
  1193. each update. Default: ``False``.
  1194. Example:
  1195. >>> data_loader = torch.utils.data.DataLoader(...)
  1196. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  1197. >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
  1198. >>> for epoch in range(10):
  1199. >>> for batch in data_loader:
  1200. >>> train_batch(...)
  1201. >>> scheduler.step()
  1202. .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
  1203. https://arxiv.org/abs/1708.07120
  1204. """
  1205. def __init__(self,
  1206. optimizer,
  1207. max_lr,
  1208. total_steps=None,
  1209. epochs=None,
  1210. steps_per_epoch=None,
  1211. pct_start=0.3,
  1212. anneal_strategy='cos',
  1213. cycle_momentum=True,
  1214. base_momentum=0.85,
  1215. max_momentum=0.95,
  1216. div_factor=25.,
  1217. final_div_factor=1e4,
  1218. three_phase=False,
  1219. last_epoch=-1,
  1220. verbose=False):
  1221. # Validate optimizer
  1222. if not isinstance(optimizer, Optimizer):
  1223. raise TypeError('{} is not an Optimizer'.format(
  1224. type(optimizer).__name__))
  1225. self.optimizer = optimizer
  1226. # Validate total_steps
  1227. if total_steps is None and epochs is None and steps_per_epoch is None:
  1228. raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
  1229. elif total_steps is not None:
  1230. if total_steps <= 0 or not isinstance(total_steps, int):
  1231. raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
  1232. self.total_steps = total_steps
  1233. else:
  1234. if epochs <= 0 or not isinstance(epochs, int):
  1235. raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
  1236. if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
  1237. raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
  1238. self.total_steps = epochs * steps_per_epoch
  1239. if three_phase:
  1240. self._schedule_phases = [
  1241. {
  1242. 'end_step': float(pct_start * self.total_steps) - 1,
  1243. 'start_lr': 'initial_lr',
  1244. 'end_lr': 'max_lr',
  1245. 'start_momentum': 'max_momentum',
  1246. 'end_momentum': 'base_momentum',
  1247. },
  1248. {
  1249. 'end_step': float(2 * pct_start * self.total_steps) - 2,
  1250. 'start_lr': 'max_lr',
  1251. 'end_lr': 'initial_lr',
  1252. 'start_momentum': 'base_momentum',
  1253. 'end_momentum': 'max_momentum',
  1254. },
  1255. {
  1256. 'end_step': self.total_steps - 1,
  1257. 'start_lr': 'initial_lr',
  1258. 'end_lr': 'min_lr',
  1259. 'start_momentum': 'max_momentum',
  1260. 'end_momentum': 'max_momentum',
  1261. },
  1262. ]
  1263. else:
  1264. self._schedule_phases = [
  1265. {
  1266. 'end_step': float(pct_start * self.total_steps) - 1,
  1267. 'start_lr': 'initial_lr',
  1268. 'end_lr': 'max_lr',
  1269. 'start_momentum': 'max_momentum',
  1270. 'end_momentum': 'base_momentum',
  1271. },
  1272. {
  1273. 'end_step': self.total_steps - 1,
  1274. 'start_lr': 'max_lr',
  1275. 'end_lr': 'min_lr',
  1276. 'start_momentum': 'base_momentum',
  1277. 'end_momentum': 'max_momentum',
  1278. },
  1279. ]
  1280. # Validate pct_start
  1281. if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
  1282. raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
  1283. # Validate anneal_strategy
  1284. if anneal_strategy not in ['cos', 'linear']:
  1285. raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
  1286. elif anneal_strategy == 'cos':
  1287. self.anneal_func = self._annealing_cos
  1288. elif anneal_strategy == 'linear':
  1289. self.anneal_func = self._annealing_linear
  1290. # Initialize learning rate variables
  1291. max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
  1292. if last_epoch == -1:
  1293. for idx, group in enumerate(self.optimizer.param_groups):
  1294. group['initial_lr'] = max_lrs[idx] / div_factor
  1295. group['max_lr'] = max_lrs[idx]
  1296. group['min_lr'] = group['initial_lr'] / final_div_factor
  1297. # Initialize momentum variables
  1298. self.cycle_momentum = cycle_momentum
  1299. if self.cycle_momentum:
  1300. if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
  1301. raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
  1302. self.use_beta1 = 'betas' in self.optimizer.defaults
  1303. max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
  1304. base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
  1305. if last_epoch == -1:
  1306. for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
  1307. if self.use_beta1:
  1308. _, beta2 = group['betas']
  1309. group['betas'] = (m_momentum, beta2)
  1310. else:
  1311. group['momentum'] = m_momentum
  1312. group['max_momentum'] = m_momentum
  1313. group['base_momentum'] = b_momentum
  1314. super(OneCycleLR, self).__init__(optimizer, last_epoch, verbose)
  1315. def _format_param(self, name, optimizer, param):
  1316. """Return correctly formatted lr/momentum for each param group."""
  1317. if isinstance(param, (list, tuple)):
  1318. if len(param) != len(optimizer.param_groups):
  1319. raise ValueError("expected {} values for {}, got {}".format(
  1320. len(optimizer.param_groups), name, len(param)))
  1321. return param
  1322. else:
  1323. return [param] * len(optimizer.param_groups)
  1324. def _annealing_cos(self, start, end, pct):
  1325. "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
  1326. cos_out = math.cos(math.pi * pct) + 1
  1327. return end + (start - end) / 2.0 * cos_out
  1328. def _annealing_linear(self, start, end, pct):
  1329. "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
  1330. return (end - start) * pct + start
  1331. def get_lr(self):
  1332. if not self._get_lr_called_within_step:
  1333. warnings.warn("To get the last learning rate computed by the scheduler, "
  1334. "please use `get_last_lr()`.", UserWarning)
  1335. lrs = []
  1336. step_num = self.last_epoch
  1337. if step_num > self.total_steps:
  1338. raise ValueError("Tried to step {} times. The specified number of total steps is {}"
  1339. .format(step_num + 1, self.total_steps))
  1340. for group in self.optimizer.param_groups:
  1341. start_step = 0
  1342. for i, phase in enumerate(self._schedule_phases):
  1343. end_step = phase['end_step']
  1344. if step_num <= end_step or i == len(self._schedule_phases) - 1:
  1345. pct = (step_num - start_step) / (end_step - start_step)
  1346. computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
  1347. if self.cycle_momentum:
  1348. computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
  1349. break
  1350. start_step = phase['end_step']
  1351. lrs.append(computed_lr)
  1352. if self.cycle_momentum:
  1353. if self.use_beta1:
  1354. _, beta2 = group['betas']
  1355. group['betas'] = (computed_momentum, beta2)
  1356. else:
  1357. group['momentum'] = computed_momentum
  1358. return lrs