| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- ## @package model_helper_api
- # Module caffe2.python.model_helper_api
- import sys
- import copy
- import inspect
- from past.builtins import basestring
- from caffe2.python.model_helper import ModelHelper
- # flake8: noqa
- from caffe2.python.helpers.algebra import *
- from caffe2.python.helpers.arg_scope import *
- from caffe2.python.helpers.array_helpers import *
- from caffe2.python.helpers.control_ops import *
- from caffe2.python.helpers.conv import *
- from caffe2.python.helpers.db_input import *
- from caffe2.python.helpers.dropout import *
- from caffe2.python.helpers.elementwise_linear import *
- from caffe2.python.helpers.fc import *
- from caffe2.python.helpers.nonlinearity import *
- from caffe2.python.helpers.normalization import *
- from caffe2.python.helpers.pooling import *
- from caffe2.python.helpers.quantization import *
- from caffe2.python.helpers.tools import *
- from caffe2.python.helpers.train import *
class HelperWrapper(object):
    """Stand-in for this module that exposes registered helpers as attributes.

    An instance of this class replaces this module in ``sys.modules`` (via the
    module-level ``sys.modules[__name__] = HelperWrapper(...)`` statement), so
    ``<module>.fc(...)`` resolves through ``__getattr__`` to a wrapper around
    the registered ``fc`` helper. When called, the wrapper builds the keyword
    arguments for the helper from three sources, later ones overriding
    earlier ones:

    1. ``model.arg_scope`` of the model passed as the first positional
       argument (or as the ``model=`` keyword) -- skipped for ``arg_scope``
       itself. Entries are only forwarded if the helper accepts them by name,
       unless the helper takes ``**kwargs``;
    2. ``get_current_scope().get(helper_name, {})`` (not filtered);
    3. the keyword arguments passed explicitly at the call site (not
       filtered).
    """

    # Maps helper name -> helper function. This is a class attribute, so it is
    # shared by every HelperWrapper instance and ``Register`` adds to it
    # globally.
    _registry = {
        'arg_scope': arg_scope,
        'fc': fc,
        'packed_fc': packed_fc,
        'fc_decomp': fc_decomp,
        'fc_sparse': fc_sparse,
        'fc_prune': fc_prune,
        'dropout': dropout,
        'max_pool': max_pool,
        'average_pool': average_pool,
        'max_pool_with_index': max_pool_with_index,
        'lrn': lrn,
        'softmax': softmax,
        'instance_norm': instance_norm,
        'spatial_bn': spatial_bn,
        'spatial_gn': spatial_gn,
        'moments_with_running_stats': moments_with_running_stats,
        'relu': relu,
        'prelu': prelu,
        'tanh': tanh,
        'concat': concat,
        'depth_concat': depth_concat,
        'sum': sum,
        'reduce_sum': reduce_sum,
        'sub': sub,
        'arg_min': arg_min,
        'transpose': transpose,
        'iter': iter,
        'accuracy': accuracy,
        'conv': conv,
        'conv_nd': conv_nd,
        'conv_transpose': conv_transpose,
        'group_conv': group_conv,
        'group_conv_deprecated': group_conv_deprecated,
        'image_input': image_input,
        'video_input': video_input,
        'add_weight_decay': add_weight_decay,
        'elementwise_linear': elementwise_linear,
        'layer_norm': layer_norm,
        'mat_mul': mat_mul,
        'batch_mat_mul': batch_mat_mul,
        'cond': cond,
        'loop': loop,
        'db_input': db_input,
        'fused_8bit_rowwise_quantized_to_float': fused_8bit_rowwise_quantized_to_float,
        'sparse_lengths_sum_4bit_rowwise_sparse': sparse_lengths_sum_4bit_rowwise_sparse,
    }

    def __init__(self, wrapped):
        """Wrap ``wrapped``, the original module object this instance replaces.

        Holding the reference presumably keeps the original module object
        (whose globals this class's methods use: ``copy``, ``inspect``,
        ``ModelHelper``, ``get_current_scope``) from being discarded once
        ``sys.modules`` no longer points at it -- TODO confirm.
        """
        self.wrapped = wrapped

    @staticmethod
    def _named_params(func):
        """Return ``(names, accepts_var_kwargs)`` describing ``func``'s parameters.

        ``names`` is the list of parameters that can be passed by keyword
        (positional-or-keyword plus keyword-only). ``accepts_var_kwargs`` is
        True when ``func`` declares ``**kwargs``.

        ``inspect.getargspec`` was removed in Python 3.11 and, on earlier
        Python 3 versions, raises ValueError for functions with keyword-only
        parameters or annotations, so ``inspect.getfullargspec`` is preferred.
        The ``getargspec`` fallback only serves interpreters that lack
        ``getfullargspec`` (Python 2).
        """
        getfullargspec = getattr(inspect, 'getfullargspec', None)
        if getfullargspec is not None:
            spec = getfullargspec(func)
            names = list(spec.args) + list(spec.kwonlyargs)
            return names, spec.varkw is not None
        args, _, varkw, _ = inspect.getargspec(func)
        return list(args), varkw is not None

    def __getattr__(self, helper_name):
        """Return a scope-aware wrapper for the registered helper ``helper_name``.

        Python only calls this for names not found by normal attribute lookup,
        so ``wrapped`` and the methods of this class take precedence.

        Raises:
            AttributeError: if ``helper_name`` is not registered.
        """
        if helper_name not in self._registry:
            raise AttributeError(
                "Helper function {} not "
                "registered.".format(helper_name)
            )

        def scope_wrapper(*args, **kwargs):
            # Build the helper's kwargs from model.arg_scope, the current
            # arg_scope context, and explicit kwargs (later ones win).
            new_kwargs = {}
            if helper_name != 'arg_scope':
                if args and isinstance(args[0], ModelHelper):
                    model = args[0]
                elif 'model' in kwargs:
                    model = kwargs['model']
                else:
                    raise RuntimeError(
                        "The first input of helper function should be model. "
                        "Or you can provide it in kwargs as model=<your_model>.")
                # Copy so the updates below (and the helper itself) never
                # mutate the model's stored arg_scope.
                new_kwargs = copy.deepcopy(model.arg_scope)
            func = self._registry[helper_name]
            arg_names, accepts_var_kwargs = HelperWrapper._named_params(func)
            if not accepts_var_kwargs:
                # The helper would reject unknown keywords, so only forward
                # model-level scope entries it actually names.
                accepted = set(arg_names)
                new_kwargs = {
                    name: value
                    for name, value in new_kwargs.items()
                    if name in accepted
                }
            cur_scope = get_current_scope()
            new_kwargs.update(cur_scope.get(helper_name, {}))
            new_kwargs.update(kwargs)
            return func(*args, **new_kwargs)

        scope_wrapper.__name__ = helper_name
        # Expose the helper's own documentation through the wrapper.
        scope_wrapper.__doc__ = getattr(self._registry[helper_name], '__doc__', None)
        return scope_wrapper

    def Register(self, helper):
        """Register ``helper`` under its ``__name__``.

        Raises:
            AttributeError: if a helper with the same name is already
                registered.
        """
        name = helper.__name__
        if name in self._registry:
            raise AttributeError(
                "Helper {} already exists. Please change your "
                "helper name.".format(name)
            )
        self._registry[name] = helper

    def has_helper(self, helper_or_helper_name):
        """Return True if a helper is registered under the given name.

        Args:
            helper_or_helper_name: either the helper's name as a string, or
                any object with a ``__name__`` attribute (e.g. the function).
        """
        helper_name = (
            helper_or_helper_name
            if isinstance(helper_or_helper_name, basestring) else
            helper_or_helper_name.__name__
        )
        return helper_name in self._registry
# Replace this module's entry in sys.modules with a HelperWrapper around the
# original module object, so attribute access on the imported name (e.g.
# ``<module>.fc``) goes through HelperWrapper.__getattr__ and yields
# scope-aware wrappers of registered helpers.
# NOTE: through the wrapper, only registered helpers, HelperWrapper's own
# methods, and ``wrapped`` (the original module) are reachable; other module
# globals (e.g. the star-imported names above) raise AttributeError unless
# accessed via ``wrapped``.
# pyre-fixme[6]: incompatible parameter type: expected ModuleType, got HelperWrapper
sys.modules[__name__] = HelperWrapper(sys.modules[__name__])
|