dropout.py 1.4 KB

# Module caffe2.python.layers.dropout
  2. from caffe2.python import schema
  3. from caffe2.python.layers.layers import ModelLayer
  4. class Dropout(ModelLayer):
  5. def __init__(
  6. self,
  7. model,
  8. input_record,
  9. name='dropout',
  10. ratio=0.5,
  11. dropout_for_eval=False,
  12. **kwargs):
  13. super(Dropout, self).__init__(model, name, input_record, **kwargs)
  14. assert isinstance(input_record, schema.Scalar), "Incorrect input type"
  15. assert (ratio >= 0 and ratio < 1.0), \
  16. "Expected 0 <= ratio < 1, but got ratio of %s" % ratio
  17. self.output_schema = input_record.clone_schema()
  18. self.output_schema.set_value(self.get_next_blob_reference('output'))
  19. self.dropout_for_eval = dropout_for_eval
  20. self.ratio = ratio
  21. def _add_ops(self, net, is_test):
  22. input_blob = self.input_record.field_blobs()
  23. output_blobs = self.output_schema.field_blobs() \
  24. + [net.NextScopedBlob('d_mask')]
  25. net.Dropout(input_blob,
  26. output_blobs,
  27. ratio=self.ratio,
  28. is_test=is_test)
  29. def add_train_ops(self, net):
  30. self._add_ops(net, is_test=False)
  31. def add_eval_ops(self, net):
  32. self._add_ops(net, is_test=(not self.dropout_for_eval))
  33. def add_ops(self, net):
  34. self.add_eval_ops(net)