| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- ## @package reservoir_sampling
- # Module caffe2.python.layers.reservoir_sampling
- from caffe2.python import core, schema
- from caffe2.python.layers.layers import ModelLayer
class ReservoirSampling(ModelLayer):
    """
    Collect samples from input record w/ reservoir sampling. If you have complex
    data, use PackRecords to pack it before using this layer.

    This layer is not thread safe.

    Args:
        model: model the layer is attached to. ``model.NoOptim`` is used as
            the optimizer for every state blob created here, so none of them
            is updated by training.
        input_record: record whose ``data`` field is fed to the
            ``ReservoirSampling`` operator. If the record also has an
            ``object_id`` field, that blob plus two extra state blobs
            (``object_to_pos`` and ``pos_to_object``) are passed to the
            operator as additional inputs/outputs.
        num_to_collect: reservoir capacity, forwarded to the operator as its
            ``num_to_collect`` argument; must be > 0.
        name: layer name, defaults to ``'reservoir_sampling'``.

    The output schema is a Struct with fields ``reservoir`` (same schema as
    ``input_record.data``, backed by the reservoir blob), ``num_visited`` and
    ``mutex``.
    """

    def __init__(self, model, input_record, num_to_collect,
                 name='reservoir_sampling', **kwargs):
        super(ReservoirSampling, self).__init__(
            model, name, input_record, **kwargs)
        assert num_to_collect > 0, (
            "num_to_collect must be positive, got {}".format(num_to_collect))
        self.num_to_collect = num_to_collect

        # Sample storage. Starts empty (shape [0]); it is both an input and an
        # output of the operator emitted in add_ops.
        self.reservoir = self.create_param(
            param_name='reservoir',
            shape=[0],
            initializer=('ConstantFill',),
            optimizer=model.NoOptim,
        )
        # INT64 scalar counter initialized to 0; also read and rewritten by
        # the operator in add_ops.
        self.num_visited_blob = self.create_param(
            param_name='num_visited',
            shape=[],
            initializer=('ConstantFill', {
                'value': 0,
                'dtype': core.DataType.INT64,
            }),
            optimizer=model.NoOptim,
        )
        # Mutex blob handed to the operator as an input (not an output).
        self.mutex = self.create_param(
            param_name='mutex',
            shape=[],
            initializer=('CreateMutex',),
            optimizer=model.NoOptim,
        )

        # Optional operator inputs/outputs appended after the mandatory ones.
        self.extra_input_blobs = []
        self.extra_output_blobs = []
        if 'object_id' in input_record:
            # NOTE(review): presumably these two blobs let the operator track
            # which object ids are already held in the reservoir and at which
            # position -- confirm against the ReservoirSampling op.
            object_to_pos = self.create_param(
                param_name='object_to_pos',
                shape=None,
                initializer=('CreateMap', {
                    'key_dtype': core.DataType.INT64,
                    # CreateMap reads 'value_dtype'; the previous spelling
                    # 'valued_dtype' was silently ignored by the op.
                    'value_dtype': core.DataType.INT32,
                }),
                optimizer=model.NoOptim,
            )
            pos_to_object = self.create_param(
                param_name='pos_to_object',
                shape=[0],
                initializer=('ConstantFill', {
                    'value': 0,
                    'dtype': core.DataType.INT64,
                }),
                optimizer=model.NoOptim,
            )
            self.extra_input_blobs.append(input_record.object_id())
            self.extra_input_blobs.extend([object_to_pos, pos_to_object])
            self.extra_output_blobs.extend([object_to_pos, pos_to_object])

        self.output_schema = schema.Struct(
            (
                'reservoir',
                schema.from_blob_list(input_record.data, [self.reservoir])
            ),
            ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
            ('mutex', schema.Scalar(blob=self.mutex)),
        )

    def add_ops(self, net):
        """Emit the ReservoirSampling operator into ``net``.

        Inputs are [reservoir, num_visited, data, mutex] followed by any
        object-id blobs. The reservoir, num_visited and object-id state blobs
        are written back in place as outputs.
        """
        net.ReservoirSampling(
            [self.reservoir, self.num_visited_blob, self.input_record.data(),
             self.mutex] + self.extra_input_blobs,
            [self.reservoir, self.num_visited_blob] + self.extra_output_blobs,
            num_to_collect=self.num_to_collect,
        )
|