initializers.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from caffe2.python.core import DataType, BlobReference, ScopedBlobReference
  2. from caffe2.python.modeling.parameter_info import ParameterInfo
  class Initializer(object):
      '''
      This class abstracts out parameter creation. One can come up with a new
      Initializer in order to implement more complex parameter initialization
      logic.

      An Initializer holds the name of an operator to invoke on ``init_net``
      plus the keyword arguments forwarded to that operator.
      '''

      def __init__(self, operator_name=None, **kwargs):
          # Operator looked up on init_net in create_param(); may be left as
          # None here and supplied later through update().
          self.operator_name = operator_name
          # Extra keyword arguments forwarded to that operator.
          self.operator_kwargs = kwargs

      def update(self, operator_name, kwargs):
          '''
          Set the operator on an initializer created without one.

          Args:
              operator_name: name of the operator to invoke on init_net.
              kwargs: mapping of keyword arguments for that operator. It is
                  copied, so later mutation by the caller does not leak into
                  this initializer (matching __init__, which always owns its
                  kwargs dict). None means "no extra arguments".

          Raises:
              ValueError: if an operator name has already been set.
          '''
          if self.operator_name is not None:
              raise ValueError("Operator name overwrites are not allowed")
          self.operator_name = operator_name
          self.operator_kwargs = dict(kwargs) if kwargs is not None else {}

      def create_param(self, param_name, init_net, shape):
          '''
          Add the configured operator to init_net and wrap its result.

          The operator is called as
          ``op([], param_name, shape=shape, **operator_kwargs)``.

          Returns:
              ParameterInfo with param_id=None, the operator's result as
              ``param`` and the given ``shape``.
          '''
          # __getattr__ is invoked explicitly rather than via getattr(), which
          # bypasses normal attribute lookup. Presumably this guarantees the
          # net's operator-creation hook is used even if the operator name
          # collides with a regular attribute of the net.
          # NOTE(review): confirm.
          param = init_net.__getattr__(self.operator_name)(
              [], param_name, shape=shape, **self.operator_kwargs)
          return ParameterInfo(
              param_id=None,
              param=param,
              shape=shape,
          )
  class ExternalInitializer(object):
      '''
      Initializer for parameters supplied from outside.

      No fill operator is added; create_param only builds a reference to a
      blob that is expected to already be present in the workspace when
      param_init_net is executed. The current version performs no real sanity
      checks on that blob.
      '''

      def create_param(self, param_name, init_net, shape):
          '''
          Return a ParameterInfo referencing param_name on init_net.

          A BlobReference argument is rebuilt from its string name via
          BlobReference. A plain str is wrapped with ScopedBlobReference
          (presumably applying the current name scope; confirm). Any other
          type raises TypeError.
          '''
          given_blob_ref = isinstance(param_name, BlobReference)
          if not (given_blob_ref or isinstance(param_name, str)):
              raise TypeError("Unsupported type for param_name")

          if given_blob_ref:
              blob = BlobReference(str(param_name), init_net)
          else:
              blob = ScopedBlobReference(param_name, init_net)

          # TODO(amalevich): Add operator that will check param in the workspace
          return ParameterInfo(param_id=None, param=blob, shape=shape)
  class PseudoFP16Initializer(Initializer):
      '''
      Used in cases when the parameter should be used at half (16-bit)
      precision for compute purposes (i.e. on the forward and backward pass)
      but needs to be stored and optimized at single (32-bit) precision so
      tiny gradients with small learning rates don't underflow FP16 precision.
      A 32-bit copy of the 16-bit blob is stored in the ParameterInfo.
      This is helpful for mixed-precision training, see
      https://arxiv.org/abs/1710.03740 for details.

      ``update`` is inherited unchanged from Initializer.
      '''

      def create_param(self, param_name, init_net, shape):
          '''
          Create the FP32 master blob ``param_name + "_fp32"`` with the
          configured operator, then cast it to FP16 under ``param_name``.

          Returns:
              ParameterInfo whose ``param`` is the FP16 blob and whose
              ``blob_copy`` maps DataType.FLOAT to the FP32 master blob.
          '''
          # Master FP32 copy: the precision the parameter is stored and
          # optimized at.
          param_fp32 = init_net.__getattr__(self.operator_name)(
              [], param_name + "_fp32", shape=shape,
              **self.operator_kwargs)
          # FP16 copy, exposed as the primary parameter for compute.
          param = init_net.FloatToHalf(param_fp32, param_name)
          return ParameterInfo(
              param_id=None,
              param=param,
              shape=shape,
              blob_copy={DataType.FLOAT: param_fp32},
          )
  class ReversePseudoFP16Initializer(Initializer):
      '''
      Like PseudoFP16Initializer, except the primary blob is taken to be the
      32-bit precision parameter, and the 16-bit version of the blob is stored
      in blob_copy instead.

      ``update`` is inherited unchanged from Initializer.
      '''

      def create_param(self, param_name, init_net, shape):
          '''
          Create the FP32 blob ``param_name`` with the configured operator,
          then cast it to FP16 as ``param_name + "_fp16"``.

          Returns:
              ParameterInfo whose ``param`` is the FP32 blob and whose
              ``blob_copy`` maps DataType.FLOAT16 to the FP16 blob.
          '''
          # Master FP32 copy, exposed as the primary parameter.
          param_fp32 = init_net.__getattr__(self.operator_name)(
              [], param_name, shape=shape,
              **self.operator_kwargs)
          # FP16 copy, kept alongside in blob_copy.
          param_fp16 = init_net.FloatToHalf(param_fp32, param_name + "_fp16")
          return ParameterInfo(
              param_id=None,
              param=param_fp32,
              shape=shape,
              blob_copy={DataType.FLOAT16: param_fp16},
          )
  def update_initializer(initializer_class,
                         operator_name_and_kwargs,
                         default_operator_name_and_kwargs):
      '''
      A helper function to convert from operator_name_and_kwargs to a new
      object of type initializer_class. This function serves two purposes:

      1. Support for custom initialization operators being passed in.
      2. Allow the user to specify a custom Initializer without overwriting
         the default operators used for initialization.

      Args:
          initializer_class: class to instantiate; if None, the Initializer
              class is used.
          operator_name_and_kwargs: ``(operator_name, kwargs_dict)`` pair. If
              falsy (e.g. None), default_operator_name_and_kwargs is used.
          default_operator_name_and_kwargs: fallback
              ``(operator_name, kwargs_dict)`` pair.

      Returns:
          An instance built as ``cls(operator_name, **kwargs_dict)``.

      Raises:
          TypeError: if operator_name_and_kwargs is falsy and
              default_operator_name_and_kwargs is None.
      '''
      # Resolve the (name, kwargs) pair once instead of on every access.
      name_and_kwargs = (
          operator_name_and_kwargs or default_operator_name_and_kwargs)
      if name_and_kwargs is None:
          raise TypeError(
              "update_initializer needs operator_name_and_kwargs or "
              "default_operator_name_and_kwargs, got neither")
      operator_name = name_and_kwargs[0]
      operator_kwargs = name_and_kwargs[1]

      # Only an explicit None selects the base Initializer class.
      if initializer_class is None:
          initializer_class = Initializer
      return initializer_class(operator_name, **operator_kwargs)