sampling_train.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. ## @package sampling_train
  2. # Module caffe2.python.layers.sampling_train
  3. from caffe2.python import schema
  4. from caffe2.python.layers.layers import ModelLayer, get_layer_class
  5. from caffe2.python.layers.sampling_trainable_mixin import SamplingTrainableMixin
  6. class SamplingTrain(ModelLayer):
  7. def __init__(
  8. self,
  9. model,
  10. input_record,
  11. prediction_layer,
  12. output_dims,
  13. subtract_log_odd=True,
  14. name='sampling_train',
  15. **kwargs
  16. ):
  17. super(SamplingTrain, self).__init__(
  18. model, name, input_record, **kwargs
  19. )
  20. layer_class = get_layer_class(prediction_layer)
  21. assert issubclass(layer_class, SamplingTrainableMixin)
  22. assert 'indices' in input_record
  23. assert isinstance(input_record.indices, schema.Scalar),\
  24. "input_record.indices is expected to be a schema.Scalar"
  25. assert 'input' in input_record
  26. self.subtract_log_odd = subtract_log_odd
  27. if self.subtract_log_odd:
  28. assert 'sampling_prob' in input_record
  29. self._prediction_layer = layer_class(
  30. model,
  31. input_record.input,
  32. output_dims=output_dims,
  33. **kwargs
  34. )
  35. self._prediction_layer.train_param_blobs = [
  36. model.net.NextBlob(str(blob) + '_sampled')
  37. for blob in self._prediction_layer.param_blobs
  38. ]
  39. self.params = self._prediction_layer.params
  40. self.output_schema = self._prediction_layer.output_schema
  41. def add_ops(self, net):
  42. self._prediction_layer.add_ops(net)
  43. def add_train_ops(self, net):
  44. for full_blob, sampled_blob in zip(
  45. self._prediction_layer.param_blobs,
  46. self._prediction_layer.train_param_blobs
  47. ):
  48. net.Gather([full_blob, self.input_record.indices()], sampled_blob)
  49. self._prediction_layer.add_train_ops(net)
  50. if not self.subtract_log_odd:
  51. return
  52. log_q = net.Log(self.input_record.sampling_prob(),
  53. net.NextScopedBlob("log_q"))
  54. net.Sub([self.output_schema(), log_q], self.output_schema(),
  55. broadcast=1, use_grad_hack=1)