adaptive_weight.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # @package adaptive_weight
  2. # Module caffe2.fb.python.layers.adaptive_weight
  3. import numpy as np
  4. from caffe2.python import core, schema
  5. from caffe2.python.layers.layers import ModelLayer
  6. from caffe2.python.regularizer import BoundedGradientProjection, LogBarrier
  7. """
  8. Implementation of adaptive weighting: https://arxiv.org/pdf/1705.07115.pdf
  9. """
  class AdaptiveWeight(ModelLayer):
      """Fold several input blobs into one scalar using learned weights.

      Adaptive weighting as described in https://arxiv.org/pdf/1705.07115.pdf.
      Every field blob of ``input_record`` is reshaped to shape ``[1]``, the
      results are concatenated, multiplied element-wise by a learned weight
      vector, a per-element regularization term is added, and everything is
      summed into the single scalar exposed through ``output_schema``.

      Two parameterizations are available, selected by ``estimation_method``:

      * ``"log_std"``: learns ``mu``; weight is ``0.5 * exp(-mu)`` and the
        regularization term is ``0.5 * mu``.
      * ``"inv_var"``: learns ``k``; weight is ``0.5 * k`` and the
        regularization term is ``-0.5 * log(k)``. ``k`` is kept positive by
        the regularizer chosen through ``pos_optim_method``.
      """

      def __init__(
          self,
          model,
          input_record,
          name="adaptive_weight",
          optimizer=None,
          weights=None,
          enable_diagnose=False,
          estimation_method="log_std",
          pos_optim_method="log_barrier",
          reg_lambda=0.1,
          **kwargs
      ):
          """Create the layer and its learnable parameter.

          Args:
              model: model the layer belongs to (forwarded to ``ModelLayer``).
              input_record: record whose field blobs are the values to combine.
              name: layer name.
              optimizer: optimizer attached to the learned parameter.
              weights: optional initial weights, one per input blob, all
                  strictly positive. Defaults to ``1 / num`` for each blob.
              enable_diagnose: when true, one extra blob per input is created
                  for its weight and registered with
                  ``model.add_ad_hoc_plot_blob``.
              estimation_method: ``"log_std"`` or ``"inv_var"``
                  (case-insensitive); selects the ``<method>_init``,
                  ``<method>_weight`` and ``<method>_reg`` methods.
              pos_optim_method: ``"log_barrier"`` or ``"pos_grad_proj"``
                  (case-insensitive); only consulted by ``inv_var_init``.
              reg_lambda: strength handed to ``LogBarrier`` when
                  ``pos_optim_method`` is ``"log_barrier"``.
              **kwargs: forwarded to ``ModelLayer.__init__``.

          Raises:
              AssertionError: if ``weights`` has the wrong length or contains
                  a non-positive value.
              AttributeError: if ``estimation_method`` is not supported.
              TypeError: if ``estimation_method`` is ``"inv_var"`` and
                  ``pos_optim_method`` is not recognized.
          """
          super(AdaptiveWeight, self).__init__(model, name, input_record, **kwargs)
          self.output_schema = schema.Scalar(
              np.float32, self.get_next_blob_reference("adaptive_weight")
          )
          self.data = self.input_record.field_blobs()
          self.num = len(self.data)
          self.optimizer = optimizer

          # Validate with explicit raises rather than `assert` statements so the
          # checks are not stripped under `python -O`; a non-positive weight
          # would otherwise reach np.log() in log_std_init or seed a
          # non-positive k in inv_var_init. AssertionError is kept as the
          # exception type for backward compatibility.
          if weights is None:
              weights = [1. / self.num for _ in range(self.num)]
          elif len(weights) != self.num:
              raise AssertionError(
                  "expected {} initial weights, got {}".format(
                      self.num, len(weights)
                  )
              )
          if not min(weights) > 0:
              raise AssertionError("initial weights must be positive")
          self.weights = np.array(weights).astype(np.float32)

          self.estimation_method = str(estimation_method).lower()
          # Only relevant for the positivity-constrained parameterization
          # (estimation_method == "inv_var"): chooses between the log-barrier
          # and the gradient-projection regularizer.
          self.pos_optim_method = str(pos_optim_method).lower()
          self.reg_lambda = float(reg_lambda)
          self.enable_diagnose = enable_diagnose

          # Resolve the three strategy methods by name. The lookup alone is
          # guarded so that an unsupported method name yields a clear message,
          # while errors raised inside init_func() itself are not masked.
          try:
              self.init_func = getattr(self, self.estimation_method + "_init")
              self.weight_func = getattr(self, self.estimation_method + "_weight")
              self.reg_func = getattr(self, self.estimation_method + "_reg")
          except AttributeError:
              raise AttributeError(
                  "unknown estimation method: {}".format(self.estimation_method)
              )
          self.init_func()

          if self.enable_diagnose:
              self.weight_i = [
                  self.get_next_blob_reference("adaptive_weight_%d" % i)
                  for i in range(self.num)
              ]
              for blob in self.weight_i:
                  self.model.add_ad_hoc_plot_blob(blob)

      def concat_data(self, net):
          """Reshape every input blob to shape [1] and concatenate them.

          Returns the blob holding the concatenation (along axis 0) of all
          inputs.
          """
          reshaped = [
              net.NextScopedBlob("reshaped_data_%d" % i) for i in range(self.num)
          ]
          # Coerce shape for single real values. The second output of Reshape
          # is an auxiliary blob that is not used afterwards.
          for i in range(self.num):
              net.Reshape(
                  [self.data[i]],
                  [reshaped[i], net.NextScopedBlob("new_shape_%d" % i)],
                  shape=[1],
              )
          concated = net.NextScopedBlob("concated_data")
          # The second output of Concat is likewise auxiliary and unused.
          net.Concat(
              reshaped, [concated, net.NextScopedBlob("concated_new_shape")], axis=0
          )
          return concated

      def log_std_init(self):
          """Create the parameter ``mu`` for the "log_std" parameterization.

          mu = 2 log sigma, where sigma is the standard deviation.
          Per-task objective:
              min 1 / 2 / e^mu X + mu / 2

          ``mu`` is initialized to ``log(1 / (2 * w))`` so that the initial
          weight ``0.5 * exp(-mu)`` equals the requested initial weight ``w``.
          """
          values = np.log(1. / 2. / self.weights)
          initializer = (
              "GivenTensorFill",
              {"values": values, "dtype": core.DataType.FLOAT},
          )
          self.mu = self.create_param(
              param_name="mu",
              shape=[self.num],
              initializer=initializer,
              optimizer=self.optimizer,
          )

      def log_std_weight(self, x, net, weight):
          """Write ``0.5 * exp(-mu)`` into the ``weight`` blob.

          Corresponds to the data term of: min 1 / 2 / e^mu X + mu / 2.
          ``x`` is accepted for signature parity with the other ``*_weight``
          method and is not used.
          """
          mu_neg = net.NextScopedBlob("mu_neg")
          net.Negative(self.mu, mu_neg)
          mu_neg_exp = net.NextScopedBlob("mu_neg_exp")
          net.Exp(mu_neg, mu_neg_exp)
          net.Scale(mu_neg_exp, weight, scale=0.5)

      def log_std_reg(self, net, reg):
          """Write the regularization term ``0.5 * mu`` into the ``reg`` blob."""
          net.Scale(self.mu, reg, scale=0.5)

      def inv_var_init(self):
          """Create the parameter ``k`` for the "inv_var" parameterization.

          k = 1 / variance.
          Per-task objective:
              min 1 / 2 * k X - 1 / 2 * log k

          ``k`` is initialized to ``2 * w`` so that the initial weight
          ``0.5 * k`` equals the requested initial weight ``w``. The regularizer
          selected by ``pos_optim_method`` is attached to keep ``k`` positive.

          Raises:
              TypeError: if ``pos_optim_method`` is neither "log_barrier" nor
                  "pos_grad_proj".
          """
          values = 2. * self.weights
          initializer = (
              "GivenTensorFill",
              {"values": values, "dtype": core.DataType.FLOAT},
          )
          if self.pos_optim_method == "log_barrier":
              regularizer = LogBarrier(reg_lambda=self.reg_lambda)
          elif self.pos_optim_method == "pos_grad_proj":
              regularizer = BoundedGradientProjection(lb=0, left_open=True)
          else:
              raise TypeError(
                  "unknown positivity optimization method: {}".format(
                      self.pos_optim_method
                  )
              )
          self.k = self.create_param(
              param_name="k",
              shape=[self.num],
              initializer=initializer,
              optimizer=self.optimizer,
              regularizer=regularizer,
          )

      def inv_var_weight(self, x, net, weight):
          """Write ``0.5 * k`` into the ``weight`` blob.

          ``x`` is accepted for signature parity with the other ``*_weight``
          method and is not used.
          """
          net.Scale(self.k, weight, scale=0.5)

      def inv_var_reg(self, net, reg):
          """Write the regularization term ``-0.5 * log(k)`` into ``reg``."""
          log_k = net.NextScopedBlob("log_k")
          net.Log(self.k, log_k)
          net.Scale(log_k, reg, scale=-0.5)

      def _add_ops_impl(self, net, enable_diagnose):
          """Emit the operators computing ``sum(weight * x + reg)``.

          Args:
              net: net the operators are added to.
              enable_diagnose: when true, each element of the weight vector is
                  additionally sliced into its own blob from ``self.weight_i``
                  (those blobs only exist if the layer was constructed with
                  ``enable_diagnose=True``).
          """
          x = self.concat_data(net)
          weight = net.NextScopedBlob("weight")
          reg = net.NextScopedBlob("reg")
          weighted_x = net.NextScopedBlob("weighted_x")
          weighted_x_add_reg = net.NextScopedBlob("weighted_x_add_reg")
          self.weight_func(x, net, weight)
          self.reg_func(net, reg)
          net.Mul([weight, x], weighted_x)
          net.Add([weighted_x, reg], weighted_x_add_reg)
          net.SumElements(weighted_x_add_reg, self.output_schema())
          if enable_diagnose:
              for i in range(self.num):
                  net.Slice(weight, self.weight_i[i], starts=[i], ends=[i + 1])

      def add_ops(self, net):
          """Add the layer's operators to ``net`` using the configured diagnose flag."""
          self._add_ops_impl(net, self.enable_diagnose)