homotopy_weight.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # @package homotopy_weight
  2. # Module caffe2.fb.python.layers.homotopy_weight
  3. from caffe2.python import core, schema
  4. from caffe2.python.layers.layers import ModelLayer
  5. import numpy as np
  6. import logging
  7. logger = logging.getLogger(__name__)
'''
Homotopy weighting between two weights x, y, computed as:
    alpha x + beta y
where alpha is a decreasing scalar parameter ranging over [min, max] (default
[0, 1]), and alpha + beta = max + min, which means that beta is increasing
over the same range [min, max].
Homotopy methods first solve an "easy" problem (one whose solution is well
known), which is then gradually transformed into the target problem.
'''
  17. class HomotopyWeight(ModelLayer):
  18. def __init__(
  19. self,
  20. model,
  21. input_record,
  22. name='homotopy_weight',
  23. min_weight=0.,
  24. max_weight=1.,
  25. half_life=1e6,
  26. quad_life=3e6,
  27. atomic_iter=None,
  28. **kwargs
  29. ):
  30. super(HomotopyWeight,
  31. self).__init__(model, name, input_record, **kwargs)
  32. self.output_schema = schema.Scalar(
  33. np.float32, self.get_next_blob_reference('homotopy_weight')
  34. )
  35. data = self.input_record.field_blobs()
  36. assert len(data) == 2
  37. self.x = data[0]
  38. self.y = data[1]
  39. # TODO: currently model building does not have access to iter counter or
  40. # learning rate; it's added at optimization time;
  41. self.use_external_iter = (atomic_iter is not None)
  42. self.atomic_iter = (
  43. atomic_iter if self.use_external_iter else self.create_atomic_iter()
  44. )
  45. # to map lr to [min, max]; alpha = scale * lr + offset
  46. assert max_weight > min_weight
  47. self.scale = float(max_weight - min_weight)
  48. self.offset = self.model.add_global_constant(
  49. '%s_offset_1dfloat' % self.name, float(min_weight)
  50. )
  51. self.gamma, self.power = self.solve_inv_lr_params(half_life, quad_life)
  52. def solve_inv_lr_params(self, half_life, quad_life):
  53. # ensure that the gamma, power is solvable
  54. assert half_life > 0
  55. # convex monotonically decreasing
  56. assert quad_life > 2 * half_life
  57. t = float(quad_life) / float(half_life)
  58. x = t * (1.0 + np.sqrt(2.0)) / 2.0 - np.sqrt(2.0)
  59. gamma = (x - 1.0) / float(half_life)
  60. power = np.log(2.0) / np.log(x)
  61. logger.info(
  62. 'homotopy_weighting: found lr param: gamma=%g, power=%g' %
  63. (gamma, power)
  64. )
  65. return gamma, power
  66. def create_atomic_iter(self):
  67. self.mutex = self.create_param(
  68. param_name=('%s_mutex' % self.name),
  69. shape=None,
  70. initializer=('CreateMutex', ),
  71. optimizer=self.model.NoOptim,
  72. )
  73. self.atomic_iter = self.create_param(
  74. param_name=('%s_atomic_iter' % self.name),
  75. shape=[1],
  76. initializer=(
  77. 'ConstantFill', {
  78. 'value': 0,
  79. 'dtype': core.DataType.INT64
  80. }
  81. ),
  82. optimizer=self.model.NoOptim,
  83. )
  84. return self.atomic_iter
  85. def update_weight(self, net):
  86. alpha = net.NextScopedBlob('alpha')
  87. beta = net.NextScopedBlob('beta')
  88. lr = net.NextScopedBlob('lr')
  89. comp_lr = net.NextScopedBlob('complementary_lr')
  90. scaled_lr = net.NextScopedBlob('scaled_lr')
  91. scaled_comp_lr = net.NextScopedBlob('scaled_complementary_lr')
  92. if not self.use_external_iter:
  93. net.AtomicIter([self.mutex, self.atomic_iter], [self.atomic_iter])
  94. net.LearningRate(
  95. [self.atomic_iter],
  96. [lr],
  97. policy='inv',
  98. gamma=self.gamma,
  99. power=self.power,
  100. base_lr=1.0,
  101. )
  102. net.Sub([self.model.global_constants['ONE'], lr], [comp_lr])
  103. net.Scale([lr], [scaled_lr], scale=self.scale)
  104. net.Scale([comp_lr], [scaled_comp_lr], scale=self.scale)
  105. net.Add([scaled_lr, self.offset], [alpha])
  106. net.Add([scaled_comp_lr, self.offset], [beta])
  107. return alpha, beta
  108. def add_ops(self, net):
  109. alpha, beta = self.update_weight(net)
  110. # alpha x + beta y
  111. net.WeightedSum([self.x, alpha, self.y, beta], self.output_schema())