pooling.py 924 B

1234567891011121314151617181920212223242526272829303132333435363738
## @package pooling
# Module caffe2.python.helpers.pooling
  5. def max_pool(model, blob_in, blob_out, use_cudnn=False, order="NCHW", **kwargs):
  6. """Max pooling"""
  7. if use_cudnn:
  8. kwargs['engine'] = 'CUDNN'
  9. return model.net.MaxPool(blob_in, blob_out, order=order, **kwargs)
  10. def average_pool(model, blob_in, blob_out, use_cudnn=False, order="NCHW",
  11. **kwargs):
  12. """Average pooling"""
  13. if use_cudnn:
  14. kwargs['engine'] = 'CUDNN'
  15. return model.net.AveragePool(
  16. blob_in,
  17. blob_out,
  18. order=order,
  19. **kwargs
  20. )
  21. def max_pool_with_index(model, blob_in, blob_out, order="NCHW", **kwargs):
  22. """Max pooling with an explicit index of max position"""
  23. return model.net.MaxPoolWithIndex(
  24. blob_in,
  25. [blob_out, blob_out + "_index"],
  26. order=order,
  27. **kwargs
  28. )[0]