_functional.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. r"""Functional interface"""
  2. import math
  3. from torch import Tensor
  4. from typing import List
  5. from .adadelta import adadelta # type: ignore[attr-defined] # noqa: F401
  6. from .adagrad import adagrad, _make_sparse # type: ignore[attr-defined] # noqa: F401
  7. from .adam import adam # type: ignore[attr-defined] # noqa: F401
  8. from .adamw import adamw # type: ignore[attr-defined] # noqa: F401
  9. from .adamax import adamax # type: ignore[attr-defined] # noqa: F401
  10. from .asgd import asgd # type: ignore[attr-defined] # noqa: F401
  11. from .nadam import nadam # type: ignore[attr-defined] # noqa: F401
  12. from .radam import radam # type: ignore[attr-defined] # noqa: F401
  13. from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401
  14. from .rprop import rprop # type: ignore[attr-defined] # noqa: F401
  15. from .sgd import sgd # type: ignore[attr-defined] # noqa: F401
  16. # TODO: use foreach API in optim._functional to do all the computation
  17. def sparse_adam(params: List[Tensor],
  18. grads: List[Tensor],
  19. exp_avgs: List[Tensor],
  20. exp_avg_sqs: List[Tensor],
  21. state_steps: List[int],
  22. *,
  23. eps: float,
  24. beta1: float,
  25. beta2: float,
  26. lr: float):
  27. r"""Functional API that performs Sparse Adam algorithm computation.
  28. See :class:`~torch.optim.SparseAdam` for details.
  29. """
  30. for i, param in enumerate(params):
  31. grad = grads[i]
  32. grad = grad.coalesce() # the update is non-linear so indices must be unique
  33. grad_indices = grad._indices()
  34. grad_values = grad._values()
  35. size = grad.size()
  36. exp_avg = exp_avgs[i]
  37. exp_avg_sq = exp_avg_sqs[i]
  38. step = state_steps[i]
  39. def make_sparse(values):
  40. constructor = grad.new
  41. if grad_indices.dim() == 0 or values.dim() == 0:
  42. return constructor().resize_as_(grad)
  43. return constructor(grad_indices, values, size)
  44. # Decay the first and second moment running average coefficient
  45. # old <- b * old + (1 - b) * new
  46. # <==> old += (1 - b) * (new - old)
  47. old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
  48. exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
  49. exp_avg.add_(make_sparse(exp_avg_update_values))
  50. old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
  51. exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
  52. exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
  53. # Dense addition again is intended, avoiding another sparse_mask
  54. numer = exp_avg_update_values.add_(old_exp_avg_values)
  55. exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
  56. denom = exp_avg_sq_update_values.sqrt_().add_(eps)
  57. del exp_avg_update_values, exp_avg_sq_update_values
  58. bias_correction1 = 1 - beta1 ** step
  59. bias_correction2 = 1 - beta2 ** step
  60. step_size = lr * math.sqrt(bias_correction2) / bias_correction1
  61. param.add_(make_sparse(-step_size * numer.div_(denom)))