constant_weight.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. # @package constant_weight
  2. # Module caffe2.fb.python.layers.constant_weight
  3. from caffe2.python import schema
  4. from caffe2.python.layers.layers import ModelLayer
  5. import numpy as np
  6. class ConstantWeight(ModelLayer):
  7. def __init__(
  8. self,
  9. model,
  10. input_record,
  11. weights=None,
  12. name='constant_weight',
  13. **kwargs
  14. ):
  15. super(ConstantWeight,
  16. self).__init__(model, name, input_record, **kwargs)
  17. self.output_schema = schema.Scalar(
  18. np.float32, self.get_next_blob_reference('constant_weight')
  19. )
  20. self.data = self.input_record.field_blobs()
  21. self.num = len(self.data)
  22. weights = (
  23. weights if weights is not None else
  24. [1. / self.num for _ in range(self.num)]
  25. )
  26. assert len(weights) == self.num
  27. self.weights = [
  28. self.model.add_global_constant(
  29. '%s_weight_%d' % (self.name, i), float(weights[i])
  30. ) for i in range(self.num)
  31. ]
  32. def add_ops(self, net):
  33. net.WeightedSum(
  34. [b for x_w_pair in zip(self.data, self.weights) for b in x_w_pair],
  35. self.output_schema()
  36. )