reservoir_sampling.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. ## @package reservoir_sampling
  2. # Module caffe2.python.layers.reservoir_sampling
  3. from caffe2.python import core, schema
  4. from caffe2.python.layers.layers import ModelLayer
  5. class ReservoirSampling(ModelLayer):
  6. """
  7. Collect samples from input record w/ reservoir sampling. If you have complex
  8. data, use PackRecords to pack it before using this layer.
  9. This layer is not thread safe.
  10. """
  11. def __init__(self, model, input_record, num_to_collect,
  12. name='reservoir_sampling', **kwargs):
  13. super(ReservoirSampling, self).__init__(
  14. model, name, input_record, **kwargs)
  15. assert num_to_collect > 0
  16. self.num_to_collect = num_to_collect
  17. self.reservoir = self.create_param(
  18. param_name='reservoir',
  19. shape=[0],
  20. initializer=('ConstantFill',),
  21. optimizer=model.NoOptim,
  22. )
  23. self.num_visited_blob = self.create_param(
  24. param_name='num_visited',
  25. shape=[],
  26. initializer=('ConstantFill', {
  27. 'value': 0,
  28. 'dtype': core.DataType.INT64,
  29. }),
  30. optimizer=model.NoOptim,
  31. )
  32. self.mutex = self.create_param(
  33. param_name='mutex',
  34. shape=[],
  35. initializer=('CreateMutex',),
  36. optimizer=model.NoOptim,
  37. )
  38. self.extra_input_blobs = []
  39. self.extra_output_blobs = []
  40. if 'object_id' in input_record:
  41. object_to_pos = self.create_param(
  42. param_name='object_to_pos',
  43. shape=None,
  44. initializer=('CreateMap', {
  45. 'key_dtype': core.DataType.INT64,
  46. 'valued_dtype': core.DataType.INT32,
  47. }),
  48. optimizer=model.NoOptim,
  49. )
  50. pos_to_object = self.create_param(
  51. param_name='pos_to_object',
  52. shape=[0],
  53. initializer=('ConstantFill', {
  54. 'value': 0,
  55. 'dtype': core.DataType.INT64,
  56. }),
  57. optimizer=model.NoOptim,
  58. )
  59. self.extra_input_blobs.append(input_record.object_id())
  60. self.extra_input_blobs.extend([object_to_pos, pos_to_object])
  61. self.extra_output_blobs.extend([object_to_pos, pos_to_object])
  62. self.output_schema = schema.Struct(
  63. (
  64. 'reservoir',
  65. schema.from_blob_list(input_record.data, [self.reservoir])
  66. ),
  67. ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
  68. ('mutex', schema.Scalar(blob=self.mutex)),
  69. )
  70. def add_ops(self, net):
  71. net.ReservoirSampling(
  72. [self.reservoir, self.num_visited_blob, self.input_record.data(),
  73. self.mutex] + self.extra_input_blobs,
  74. [self.reservoir, self.num_visited_blob] + self.extra_output_blobs,
  75. num_to_collect=self.num_to_collect,
  76. )