utils.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from typing import Type
  2. from torch import optim
  3. from .functional_adagrad import _FunctionalAdagrad
  4. from .functional_adam import _FunctionalAdam
  5. from .functional_adamw import _FunctionalAdamW
  6. from .functional_sgd import _FunctionalSGD
  7. from .functional_adadelta import _FunctionalAdadelta
  8. from .functional_rmsprop import _FunctionalRMSprop
  9. from .functional_rprop import _FunctionalRprop
  10. from .functional_adamax import _FunctionalAdamax
  11. # dict to map a user passed in optimizer_class to a functional
  12. # optimizer class if we have already defined inside the
  13. # distributed.optim package, this is so that we hide the
  14. # functional optimizer to user and still provide the same API.
  15. functional_optim_map = {
  16. optim.Adagrad: _FunctionalAdagrad,
  17. optim.Adam: _FunctionalAdam,
  18. optim.AdamW: _FunctionalAdamW,
  19. optim.SGD: _FunctionalSGD,
  20. optim.Adadelta: _FunctionalAdadelta,
  21. optim.RMSprop: _FunctionalRMSprop,
  22. optim.Rprop: _FunctionalRprop,
  23. optim.Adamax: _FunctionalAdamax,
  24. }
  25. def as_functional_optim(optim_cls: Type, *args, **kwargs):
  26. try:
  27. functional_cls = functional_optim_map[optim_cls]
  28. except KeyError:
  29. raise ValueError(f"Optimizer {optim_cls} does not have a functional counterpart!")
  30. return _create_functional_optim(functional_cls, *args, **kwargs)
  31. def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
  32. return functional_optim_cls(
  33. [],
  34. *args,
  35. **kwargs,
  36. _allow_empty_param_list=True,
  37. )