brew.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. ## @package model_helper_api
  2. # Module caffe2.python.model_helper_api
  3. import sys
  4. import copy
  5. import inspect
  6. from past.builtins import basestring
  7. from caffe2.python.model_helper import ModelHelper
  8. # flake8: noqa
  9. from caffe2.python.helpers.algebra import *
  10. from caffe2.python.helpers.arg_scope import *
  11. from caffe2.python.helpers.array_helpers import *
  12. from caffe2.python.helpers.control_ops import *
  13. from caffe2.python.helpers.conv import *
  14. from caffe2.python.helpers.db_input import *
  15. from caffe2.python.helpers.dropout import *
  16. from caffe2.python.helpers.elementwise_linear import *
  17. from caffe2.python.helpers.fc import *
  18. from caffe2.python.helpers.nonlinearity import *
  19. from caffe2.python.helpers.normalization import *
  20. from caffe2.python.helpers.pooling import *
  21. from caffe2.python.helpers.quantization import *
  22. from caffe2.python.helpers.tools import *
  23. from caffe2.python.helpers.train import *
  24. class HelperWrapper(object):
  25. _registry = {
  26. 'arg_scope': arg_scope,
  27. 'fc': fc,
  28. 'packed_fc': packed_fc,
  29. 'fc_decomp': fc_decomp,
  30. 'fc_sparse': fc_sparse,
  31. 'fc_prune': fc_prune,
  32. 'dropout': dropout,
  33. 'max_pool': max_pool,
  34. 'average_pool': average_pool,
  35. 'max_pool_with_index' : max_pool_with_index,
  36. 'lrn': lrn,
  37. 'softmax': softmax,
  38. 'instance_norm': instance_norm,
  39. 'spatial_bn': spatial_bn,
  40. 'spatial_gn': spatial_gn,
  41. 'moments_with_running_stats': moments_with_running_stats,
  42. 'relu': relu,
  43. 'prelu': prelu,
  44. 'tanh': tanh,
  45. 'concat': concat,
  46. 'depth_concat': depth_concat,
  47. 'sum': sum,
  48. 'reduce_sum': reduce_sum,
  49. 'sub': sub,
  50. 'arg_min': arg_min,
  51. 'transpose': transpose,
  52. 'iter': iter,
  53. 'accuracy': accuracy,
  54. 'conv': conv,
  55. 'conv_nd': conv_nd,
  56. 'conv_transpose': conv_transpose,
  57. 'group_conv': group_conv,
  58. 'group_conv_deprecated': group_conv_deprecated,
  59. 'image_input': image_input,
  60. 'video_input': video_input,
  61. 'add_weight_decay': add_weight_decay,
  62. 'elementwise_linear': elementwise_linear,
  63. 'layer_norm': layer_norm,
  64. 'mat_mul' : mat_mul,
  65. 'batch_mat_mul' : batch_mat_mul,
  66. 'cond' : cond,
  67. 'loop' : loop,
  68. 'db_input' : db_input,
  69. 'fused_8bit_rowwise_quantized_to_float' : fused_8bit_rowwise_quantized_to_float,
  70. 'sparse_lengths_sum_4bit_rowwise_sparse': sparse_lengths_sum_4bit_rowwise_sparse,
  71. }
  72. def __init__(self, wrapped):
  73. self.wrapped = wrapped
  74. def __getattr__(self, helper_name):
  75. if helper_name not in self._registry:
  76. raise AttributeError(
  77. "Helper function {} not "
  78. "registered.".format(helper_name)
  79. )
  80. def scope_wrapper(*args, **kwargs):
  81. new_kwargs = {}
  82. if helper_name != 'arg_scope':
  83. if len(args) > 0 and isinstance(args[0], ModelHelper):
  84. model = args[0]
  85. elif 'model' in kwargs:
  86. model = kwargs['model']
  87. else:
  88. raise RuntimeError(
  89. "The first input of helper function should be model. " \
  90. "Or you can provide it in kwargs as model=<your_model>.")
  91. new_kwargs = copy.deepcopy(model.arg_scope)
  92. func = self._registry[helper_name]
  93. var_names, _, varkw, _= inspect.getargspec(func)
  94. if varkw is None:
  95. # this helper function does not take in random **kwargs
  96. new_kwargs = {
  97. var_name: new_kwargs[var_name]
  98. for var_name in var_names if var_name in new_kwargs
  99. }
  100. cur_scope = get_current_scope()
  101. new_kwargs.update(cur_scope.get(helper_name, {}))
  102. new_kwargs.update(kwargs)
  103. return func(*args, **new_kwargs)
  104. scope_wrapper.__name__ = helper_name
  105. return scope_wrapper
  106. def Register(self, helper):
  107. name = helper.__name__
  108. if name in self._registry:
  109. raise AttributeError(
  110. "Helper {} already exists. Please change your "
  111. "helper name.".format(name)
  112. )
  113. self._registry[name] = helper
  114. def has_helper(self, helper_or_helper_name):
  115. helper_name = (
  116. helper_or_helper_name
  117. if isinstance(helper_or_helper_name, basestring) else
  118. helper_or_helper_name.__name__
  119. )
  120. return helper_name in self._registry
# Replace this module's entry in sys.modules with a HelperWrapper instance.
# The original module object is handed to the wrapper, which stores it as
# `wrapped`. Attribute lookups that are not real attributes of the wrapper
# are resolved by HelperWrapper.__getattr__ against the helper registry.
# pyre-fixme[6]: incompatible parameter type: expected ModuleType, got HelperWrapper
sys.modules[__name__] = HelperWrapper(sys.modules[__name__])