layer_model_helper.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753
  1. # @package layer_model_helper
  2. # Module caffe2.python.layer_model_helper
  3. from caffe2.python import core, model_helper, schema, scope, utils, muji
  4. from caffe2.python.modeling.parameter_info import (
  5. ParameterInfo,
  6. )
  7. from caffe2.python.modeling.parameter_sharing import (
  8. parameter_sharing_context,
  9. )
  10. from caffe2.python.modeling.net_modifier import NetModifier
  11. from caffe2.python.optimizer import get_param_device, Optimizer
  12. from caffe2.python.regularizer import Regularizer, RegularizationBy
  13. from caffe2.python.layers import layers
  14. from future.utils import viewitems, viewvalues
  15. import logging
  16. import numpy as np
  17. import copy
  18. logger = logging.getLogger(__name__)
  19. class LayerModelHelper(model_helper.ModelHelper):
  20. """
  21. Model helper for building models on top of layers abstractions.
  22. Each layer is the abstraction that is higher level than Operator. Layer
  23. is responsible for ownership of it's own parameters and can easily be
  24. instantiated in multiple nets possible with different sets of ops.
  25. As an example: one can easily instantiate predict and train nets from
  26. the same set of layers, where predict net will have subset of the
  27. operators from train net.
  28. """
  29. def __init__(self, name, input_feature_schema, trainer_extra_schema,
  30. keep_blobs=False,
  31. use_attribution=True):
  32. ''' TODO(amalevich): more documnetation on input args
  33. use_attribution:
  34. if True, will generate the atrribution net for feature importance
  35. calculation; Need to turn it to false when FC is quantized as FP16
  36. This attribute access will be consistent with MTML model.
  37. '''
  38. super(LayerModelHelper, self).__init__(name=name)
  39. self._layer_names = set()
  40. self._layers = []
  41. self._param_to_shape = {}
  42. # seed default
  43. self._seed = None
  44. self._sequence_seed = True
  45. # optimizer bookkeeping
  46. self.param_to_optim = {}
  47. self.param_to_reg = {}
  48. self._default_optimizer = None
  49. self._loss = None
  50. self._prediction = []
  51. self._output_schema = None
  52. self._post_grad_net_modifiers = []
  53. self._final_net_modifiers = []
  54. # breakdown map; breakdown features are categorical (like dense) but not
  55. # necessarily used to represent data for training
  56. self._breakdown_map = None
  57. # Connect Schema to self.net. That particular instance of schmea will be
  58. # use for generation of the Layers across the network and would be used
  59. # for connection with Readers.
  60. self._input_feature_schema = schema.NewRecord(
  61. self.net,
  62. input_feature_schema
  63. ) if not keep_blobs else input_feature_schema.clone()
  64. self._trainer_extra_schema = schema.NewRecord(
  65. self.net,
  66. trainer_extra_schema
  67. ) if not keep_blobs else trainer_extra_schema.clone()
  68. self._metrics_schema = schema.Struct()
  69. self._preproc_output_schema = None
  70. self._init_global_constants()
  71. self.param_init_net = self.create_init_net('param_init_net')
  72. self._initialize_params = True
  73. self._transfer_learning_blob_name_mappings = None
  74. # additional (hard-coded) diagnose_options to report based on the model
  75. # TODO(xlwang): it's hack!
  76. self.ad_hoc_diagnose_blobs_and_operations = []
  77. self.ad_hoc_plot_blobs = []
  78. self.use_attribution = use_attribution
  79. def clear_output_schema(self):
  80. self._output_schema = None
  81. def set_initialize_params(self, initialize_params):
  82. self._initialize_params = initialize_params
  83. def add_metric_field(self, name, value):
  84. assert name not in self._metrics_schema.fields, (
  85. "Try to add metric field twice: {}".format(name))
  86. self._metrics_schema = self._metrics_schema + schema.Struct(
  87. (name, value)
  88. )
  89. # an empty white_set will skip everything
  90. def filter_metrics_schema(self, white_set):
  91. logger.info("Filter metric schema with white_set {}".format(white_set))
  92. field_names = self._metrics_schema.field_names()
  93. for name in field_names:
  94. if name not in white_set:
  95. self._metrics_schema = self._metrics_schema - schema.Struct((name, schema.Scalar()))
  96. def add_ad_hoc_plot_blob(self, blob, dtype=None):
  97. assert isinstance(
  98. blob, (str, core.BlobReference)
  99. ), "expect type str or BlobReference, but got {}".format(type(blob))
  100. dtype = dtype or (np.float, (1, ))
  101. self.add_metric_field(str(blob), schema.Scalar(dtype, blob))
  102. self.ad_hoc_plot_blobs.append(blob)
  103. @staticmethod
  104. def _get_global_constant_initializer_op(
  105. blob_name, array=None, dtype=None, initializer=None
  106. ):
  107. # to add a global constant to model, one first need to get the
  108. # initializer
  109. if array is not None:
  110. assert initializer is None,\
  111. "Only one from array and initializer should be specified"
  112. if dtype is None:
  113. array = np.array(array)
  114. else:
  115. array = np.array(array, dtype=dtype)
  116. # TODO: make GivenTensor generic
  117. op_name = None
  118. if array.dtype == np.int32:
  119. op_name = 'GivenTensorIntFill'
  120. elif array.dtype == np.int64:
  121. op_name = 'GivenTensorInt64Fill'
  122. elif array.dtype == np.str:
  123. op_name = 'GivenTensorStringFill'
  124. elif array.dtype == np.bool:
  125. op_name = 'GivenTensorBoolFill'
  126. else:
  127. op_name = 'GivenTensorFill'
  128. def initializer(blob_name):
  129. return core.CreateOperator(
  130. op_name, [],
  131. blob_name,
  132. shape=array.shape,
  133. values=array.flatten().tolist()
  134. )
  135. else:
  136. assert initializer is not None
  137. initializer_op = initializer(blob_name)
  138. return initializer_op
  139. def add_global_constant(
  140. self, name, array=None, dtype=None, initializer=None
  141. ):
  142. assert isinstance(name, str), (
  143. 'name should be a string as we are using it as map key')
  144. # This is global namescope for constants. They will be created in all
  145. # init_nets and there should be very few of them.
  146. assert name not in self.global_constants, \
  147. "%s already added in global_constants" % name
  148. blob_name = self.net.NextBlob(name)
  149. self.global_constants[name] = blob_name
  150. initializer_op = LayerModelHelper._get_global_constant_initializer_op(
  151. blob_name, array, dtype, initializer
  152. )
  153. assert blob_name not in self.global_constant_initializers, \
  154. "there is already a initializer op associated with blob %s" % \
  155. blob_name
  156. self.global_constant_initializers[blob_name] = initializer_op
  157. return blob_name
  158. def maybe_add_global_constant(self, name, *args, **kwargs):
  159. # To ad hoc add new global constants without duplication
  160. # if the name was already registered in global_constants, it will not be
  161. # added even if the intended value is different from its original value
  162. if name in self.global_constants:
  163. blob_name = self.global_constants[name]
  164. initializer_op = \
  165. LayerModelHelper._get_global_constant_initializer_op(
  166. blob_name, *args, **kwargs
  167. )
  168. # check if the original initializer is the same as the one intended
  169. # now
  170. assert utils.OpAlmostEqual(
  171. initializer_op,
  172. self.global_constant_initializers[blob_name],
  173. 'debug_info'
  174. ), \
  175. "conflict initializers for global constant %s, " \
  176. "previous %s, now %s" % (
  177. blob_name, str(initializer_op),
  178. str(self.global_constant_initializers[blob_name]))
  179. return blob_name
  180. return self.add_global_constant(name, *args, **kwargs)
  181. def _init_global_constants(self):
  182. self.global_constants = {}
  183. self.global_constant_initializers = {}
  184. self.add_global_constant('ONE', 1.0)
  185. self.add_global_constant('NAN', float("NaN"))
  186. self.add_global_constant('ZERO', 0.0)
  187. self.add_global_constant('ZERO_RANGE', [0, 0], dtype='int32')
  188. def _add_global_constants(self, init_net):
  189. for initializer_op in viewvalues(self.global_constant_initializers):
  190. init_net._net.op.extend([initializer_op])
  191. def create_init_net(self, name):
  192. init_net = core.Net(name)
  193. self._add_global_constants(init_net)
  194. return init_net
  195. def _validate_param_shape(self, param_name, shape):
  196. if param_name not in self._param_to_shape:
  197. return
  198. ref_shape = self._param_to_shape[param_name]
  199. if shape != ref_shape:
  200. raise ValueError(
  201. "Got inconsistent shapes between shared parameters "
  202. "when trying to map a blob in scope {0} to {1}. ref_shape : "
  203. " {2}, shape : {3}".format(
  204. scope.CurrentNameScope(), param_name, ref_shape, shape)
  205. )
  206. def _validate_param_optim(self, param_name, optim):
  207. # there are three possible values for optim:
  208. # 1) None (which will use self._default_optimizer after this layer is instantiated)
  209. # 2) self.NoOptim
  210. # 3) an instance of Optimizer class such as AdagradOptimizer
  211. # this implies this parameter is not shared with any other parameter so far
  212. if param_name not in self.param_to_optim:
  213. return
  214. logger.info("{} shares the same parameter with another parameter. "
  215. "Validating if the same optimizer has been specified for them.".format(
  216. param_name,
  217. ))
  218. ref_optim = self.param_to_optim[param_name]
  219. if optim is None:
  220. assert ref_optim == self._default_optimizer, (
  221. "Optim for {} is None which will fall back to use default_optimizer. "
  222. "However, the optimizer that has been specified for this shared parameter "
  223. "is {} which is different from default_optimizer {}. "
  224. "Please check the optimizers specified for parameters shared "
  225. "with {} and the default_optimizer to ensure the consistency.".format(
  226. param_name, ref_optim, self._default_optimizer, param_name
  227. )
  228. )
  229. elif optim == self.NoOptim:
  230. assert ref_optim == self.NoOptim, (
  231. "Optim for {} is NoOptim. However, the optimizer for the parameters "
  232. "shared with {} is {} which is different from NoOptim. "
  233. "Please check the optimizer specified for other parameters in the "
  234. "shared group to ensure consistency.".format(
  235. param_name, param_name, ref_optim
  236. )
  237. )
  238. elif isinstance(optim, Optimizer):
  239. assert isinstance(ref_optim, Optimizer), (
  240. "Optim for {} is an instance of Optimizer. However, the optimizer "
  241. "for the parameters shared with {} is {} which is not an instance "
  242. "of Optimizer. Please check the optimizer specified for other "
  243. " parameters in the shared group to ensure consistency.".format(
  244. param_name, param_name, ref_optim, optim
  245. )
  246. )
  247. assert type(optim) is type(ref_optim) and optim.attributes == ref_optim.attributes, (
  248. "Optim for {} is an instance of Optimizer. However, the optimizer "
  249. "for the parameters shared with {} is {}. "
  250. "This optimizer either doesn't have the same type as the current optimizer: "
  251. "{} vs {}, or its attributes such as learning rate are different from "
  252. "that of current optimizer which is {} vs {}. "
  253. "Please check the optimizer specified for other parameters in the "
  254. "shared group to ensure consistency.".format(
  255. param_name, param_name, ref_optim, type(optim), type(ref_optim), optim.attributes, ref_optim.attributes
  256. )
  257. )
  258. else:
  259. raise ValueError("optim should be either None, NoOptim, or an instance of Optimizer, Got {} ".format(optim))
  260. def create_param(self, param_name, shape, initializer, optimizer=None,
  261. ps_param=None, regularizer=None):
  262. if isinstance(param_name, core.BlobReference):
  263. param_name = str(param_name)
  264. elif isinstance(param_name, str):
  265. # Parameter name will be equal to current Namescope that got
  266. # resolved with the respect of parameter sharing of the scopes.
  267. param_name = parameter_sharing_context.get_parameter_name(
  268. param_name)
  269. else:
  270. raise ValueError("Unsupported type for param_name")
  271. param_blob = core.BlobReference(param_name)
  272. if len(initializer) == 1:
  273. init_op_args = {}
  274. else:
  275. assert len(initializer) == 2
  276. init_op_args = copy.deepcopy(initializer[1])
  277. if shape is not None:
  278. assert 'shape' not in init_op_args
  279. init_op_args.update({'shape': shape})
  280. initializer_op = None
  281. if self._initialize_params:
  282. initializer_op = core.CreateOperator(
  283. initializer[0],
  284. [],
  285. param_blob,
  286. **init_op_args
  287. )
  288. param = layers.LayerParameter(
  289. parameter=param_blob,
  290. initializer=initializer_op,
  291. optimizer=optimizer,
  292. ps_param=ps_param,
  293. regularizer=regularizer
  294. )
  295. self._validate_param_shape(param_name, shape)
  296. self._validate_param_optim(param_name, optimizer)
  297. self._param_to_shape[param_name] = shape
  298. return param
  299. def next_layer_name(self, prefix):
  300. base_name = core.ScopedName(prefix)
  301. name = base_name
  302. index = 0
  303. while name in self._layer_names:
  304. name = base_name + '_auto_' + str(index)
  305. index += 1
  306. self._layer_names.add(name)
  307. return name
  308. def add_layer(self, layer):
  309. self._layers.append(layer)
  310. for param in layer.get_parameters():
  311. assert isinstance(param.parameter, core.BlobReference)
  312. self.param_to_optim[str(param.parameter)] = \
  313. param.optimizer or self.default_optimizer
  314. self.params.append(param.parameter)
  315. if isinstance(param, layers.LayerParameter):
  316. logger.info("Add parameter regularizer {0}".format(param.parameter))
  317. self.param_to_reg[param.parameter] = param.regularizer
  318. elif isinstance(param, ParameterInfo):
  319. # TODO:
  320. # Currently, LSTM and RNNcells, which use ModelHelper instead of
  321. # LayerModelHelper as super class, are called in pooling_methods
  322. # In ModelHelper, regularization is not supported in create_param
  323. # We will unify the way of create_param of ModelHelper and
  324. # LayerModelHelper in the future.
  325. logger.info('regularization is unsupported for ParameterInfo object')
  326. else:
  327. raise ValueError(
  328. 'unknown object type besides ParameterInfo and LayerParameter: {}'
  329. .format(param)
  330. )
  331. # The primary value of adding everything to self.net - generation of the
  332. # operators right away, i.e. if error happens it'll be detected
  333. # immediately. Other than this - create_x_net should be called.
  334. layer.add_operators(self.net, self.param_init_net)
  335. return layer.output_schema
  336. def get_parameter_blobs(self):
  337. param_blobs = []
  338. for layer in self._layers:
  339. for param in layer.get_parameters():
  340. param_blobs.append(param.parameter)
  341. return param_blobs
  342. def add_post_grad_net_modifiers(self, modifier):
  343. assert modifier not in self._post_grad_net_modifiers,\
  344. "{0} is already in {1}".format(modifier, self._post_grad_net_modifiers)
  345. assert isinstance(modifier, NetModifier),\
  346. "{} has to be a NetModifier instance".format(modifier)
  347. self._post_grad_net_modifiers.append(modifier)
  348. def add_final_net_modifiers(self, modifier):
  349. assert modifier not in self._final_net_modifiers,\
  350. "{0} is already in {1}".format(modifier, self._final_net_modifiers)
  351. assert isinstance(modifier, NetModifier),\
  352. "{} has to be a NetModifier instance".format(modifier)
  353. self._final_net_modifiers.append(modifier)
  354. @property
  355. def seed(self):
  356. return self._seed
  357. @property
  358. def sequence_seed(self):
  359. return self._sequence_seed
  360. def store_seed(self, seed, sequence_seed=True):
  361. # Store seed config that will be applied to each op in the net.
  362. self._seed = seed
  363. # If sequence_seed is True, the i-th op has rand_seed=`seed + i`
  364. self._sequence_seed = sequence_seed
  365. def apply_seed(self, net):
  366. if self._seed:
  367. net.set_rand_seed(self._seed, self._sequence_seed)
  368. @property
  369. def default_optimizer(self):
  370. return self._default_optimizer
  371. @default_optimizer.setter
  372. def default_optimizer(self, optimizer):
  373. self._default_optimizer = optimizer
  374. @property
  375. def input_feature_schema(self):
  376. return self._input_feature_schema
  377. @property
  378. def trainer_extra_schema(self):
  379. return self._trainer_extra_schema
  380. @property
  381. def metrics_schema(self):
  382. """
  383. Returns the schema that represents model output that should be used for
  384. metric reporting.
  385. During the training/evaluation this schema will be appended to the
  386. schema that represents model output.
  387. """
  388. return self._metrics_schema
  389. @property
  390. def output_schema(self):
  391. assert self._output_schema is not None
  392. return self._output_schema
  393. @output_schema.setter
  394. def output_schema(self, schema):
  395. assert self._output_schema is None
  396. self._output_schema = schema
  397. @property
  398. def preproc_output_schema(self):
  399. assert self._preproc_output_schema is not None
  400. return self._preproc_output_schema
  401. @preproc_output_schema.setter
  402. def preproc_output_schema(self, schema):
  403. assert self._preproc_output_schema is None
  404. self._preproc_output_schema = schema
  405. @property
  406. def prediction(self):
  407. assert self._prediction, "model prediction is empty"
  408. return self._prediction
  409. def add_prediction(self, prediction, weight=1.0):
  410. assert prediction is not None, "Added prediction should not be None"
  411. self._prediction.append((prediction, weight))
  412. @property
  413. def transfer_learning_blob_name_mappings(self):
  414. return self._transfer_learning_blob_name_mappings
  415. @transfer_learning_blob_name_mappings.setter
  416. def transfer_learning_blob_name_mappings(self, blob_name_mappings):
  417. assert blob_name_mappings is not None, "Transfer learning blob name mappings should not be None"
  418. self._transfer_learning_blob_name_mappings = blob_name_mappings
  419. @property
  420. def loss(self):
  421. assert self._loss is not None
  422. return self._loss
  423. @loss.setter
  424. def loss(self, loss):
  425. assert self._loss is None
  426. self._loss = loss
  427. def has_loss(self):
  428. return self._loss is not None
  429. def add_loss(self, loss, name='unnamed'):
  430. assert loss is not None, "Added loss should not be None"
  431. assert isinstance(loss, schema.Scalar) or isinstance(
  432. loss, schema.Struct
  433. ), "Added loss should be a scalar or a struct"
  434. if self._loss is None:
  435. self._loss = schema.Struct((name, loss))
  436. else:
  437. # loss could've been set through model.loss directly which could be
  438. # a scalar
  439. if isinstance(self._loss, schema.Scalar):
  440. self._loss = schema.Struct(('unnamed', self._loss))
  441. prefix_base = name + '_auto_'
  442. index = 0
  443. prefix = name
  444. while prefix in self._loss:
  445. prefix = prefix_base + str(index)
  446. index += 1
  447. loss_struct = schema.Struct((prefix, loss))
  448. self._loss = self._loss + loss_struct
  449. def add_output_schema(self, name, value):
  450. assert value is not None, \
  451. 'Added output schema {} should not be None'.format(name)
  452. assert isinstance(value, schema.Scalar) or \
  453. isinstance(value, schema.Struct), \
  454. 'Added output schema {} should be a scalar or a struct.\n\
  455. Now it is {}.'.format(name, type(value))
  456. if self._output_schema is None: # be the first field
  457. self._output_schema = schema.Struct((name, value))
  458. else: # merge with other fields
  459. assert name not in self._output_schema.fields, \
  460. 'Output Schema Field {} already exists'.format(name)
  461. self._output_schema = \
  462. self._output_schema + schema.Struct((name, value))
  463. def add_trainer_extra_schema(self, trainer_extra_schema):
  464. trainer_extra_record = schema.NewRecord(self.net, trainer_extra_schema)
  465. self._trainer_extra_schema += trainer_extra_record
  466. def __getattr__(self, layer):
  467. def is_functional_layer(layer):
  468. if core.IsOperator(layer):
  469. return True
  470. elif layer.startswith('FunctionalLayer'):
  471. return True
  472. else:
  473. return False
  474. def resolve_functional_layer(layer):
  475. if core.IsOperator(layer):
  476. return layer
  477. elif layer.startswith('FunctionalLayer'):
  478. return layer[len('FunctionalLayer'):]
  479. else:
  480. raise ValueError(
  481. '%s cannot be resolved as functional layer' % layer
  482. )
  483. if layer.startswith('__'):
  484. raise AttributeError(layer)
  485. # TODO(amalevich): Add add support for ifbpy inline documentation
  486. if layers.layer_exists(layer):
  487. def wrapper(*args, **kwargs):
  488. new_layer = layers.create_layer(layer, self, *args, **kwargs)
  489. if kwargs.get("output_to_metrics", False):
  490. new_layer.export_output_for_metrics()
  491. if kwargs.get("params_to_metrics", False):
  492. new_layer.export_params_for_metrics()
  493. return self.add_layer(new_layer)
  494. return wrapper
  495. elif is_functional_layer(layer):
  496. # TODO(xlwang): Desginated layer shadows the usage of an op as a
  497. # single layer. To enforce using an op (e.g. Split) as functional
  498. # layer, one can call 'model.FunctionalLayerSplit'
  499. layer = resolve_functional_layer(layer)
  500. def wrapper(*args, **kwargs):
  501. def apply_operator(net, in_record, out_record, **kwargs):
  502. # TODO(amalevich): Switch to net.operator as soon as it gets
  503. # landed
  504. net.__getattr__(layer)(in_record.field_blobs(),
  505. out_record.field_blobs(),
  506. **kwargs)
  507. if 'name' not in kwargs:
  508. kwargs['name'] = layer
  509. new_layer = layers.create_layer(
  510. 'Functional',
  511. self, *args, function=apply_operator,
  512. **kwargs
  513. )
  514. if kwargs.get("output_to_metrics", False):
  515. new_layer.export_output_for_metrics()
  516. if kwargs.get("params_to_metrics", False):
  517. new_layer.export_params_for_metrics()
  518. return self.add_layer(new_layer)
  519. return wrapper
  520. else:
  521. # this needs to be an AttributeError to fit hasattr semantics
  522. raise AttributeError(
  523. "Trying to create non-registered layer: {}".format(layer))
  524. @property
  525. def layers(self):
  526. return self._layers
  527. def apply_regularizers_on_loss(
  528. self,
  529. train_net,
  530. train_init_net,
  531. blob_to_device=None,
  532. ):
  533. logger.info("apply regularizer on loss")
  534. for param, regularizer in viewitems(self.param_to_reg):
  535. if regularizer is None:
  536. continue
  537. logger.info("add regularizer {0} for param {1} to loss".format(regularizer, param))
  538. assert isinstance(regularizer, Regularizer)
  539. added_loss_blob = regularizer(train_net, train_init_net, param, grad=None,
  540. by=RegularizationBy.ON_LOSS)
  541. logger.info(added_loss_blob)
  542. if added_loss_blob is not None:
  543. self.add_loss(
  544. schema.Scalar(blob=added_loss_blob),
  545. str(added_loss_blob)
  546. )
  547. def apply_regularizers_after_optimizer(
  548. self,
  549. train_net,
  550. train_init_net,
  551. grad_map,
  552. blob_to_device=None,
  553. ):
  554. logger.info("apply regularizer after optimizer")
  555. CPU = muji.OnCPU()
  556. # if given, blob_to_device is a map from blob to device_option
  557. blob_to_device = blob_to_device or {}
  558. for param, regularizer in viewitems(self.param_to_reg):
  559. if regularizer is None:
  560. continue
  561. assert isinstance(regularizer, Regularizer)
  562. logger.info("add regularizer {0} for param {1} to optimizer".format(regularizer, param))
  563. device = get_param_device(
  564. param,
  565. grad_map.get(str(param)),
  566. param_to_device=blob_to_device,
  567. default_device=CPU,
  568. )
  569. with core.DeviceScope(device):
  570. regularizer(
  571. train_net, train_init_net, param, grad=grad_map.get(str(param)),
  572. by=RegularizationBy.AFTER_OPTIMIZER
  573. )
  574. def apply_post_grad_net_modifiers(
  575. self,
  576. trainer_net,
  577. trainer_init_net,
  578. grad_map,
  579. blob_to_device=None,
  580. modify_output_record=False,
  581. ):
  582. param_grad_map = {param: grad_map[param]
  583. for param in self.param_to_optim.keys() if param in grad_map}
  584. for modifier in self._post_grad_net_modifiers:
  585. modifier(trainer_net, trainer_init_net, param_grad_map,
  586. blob_to_device=blob_to_device,
  587. modify_output_record=modify_output_record)
  588. def apply_final_net_modifiers(
  589. self,
  590. trainer_net,
  591. trainer_init_net,
  592. grad_map,
  593. blob_to_device=None,
  594. modify_output_record=False,
  595. ):
  596. for modifier in self._final_net_modifiers:
  597. modifier(trainer_net, trainer_init_net, grad_map,
  598. blob_to_device=blob_to_device,
  599. modify_output_record=modify_output_record)
  600. def apply_optimizers(
  601. self,
  602. train_net,
  603. train_init_net,
  604. grad_map,
  605. blob_to_device=None,
  606. ):
  607. CPU = muji.OnCPU()
  608. # if given, blob_to_device is a map from blob to device_option
  609. blob_to_device = blob_to_device or {}
  610. for param, optimizer in viewitems(self.param_to_optim):
  611. assert optimizer is not None, \
  612. "default optimizer must have been set in add_layer"
  613. # note that not all params has gradient and thus we sent None if
  614. # gradient does not exists
  615. device = get_param_device(
  616. param,
  617. grad_map.get(str(param)),
  618. param_to_device=blob_to_device,
  619. default_device=CPU,
  620. )
  621. if device is not None:
  622. # extra info is not applicable for optimizers
  623. del device.extra_info[:]
  624. with core.DeviceScope(device):
  625. optimizer(
  626. train_net, train_init_net, param, grad_map.get(str(param)))
  627. def _GetOne(self):
  628. return self.global_constants['ONE']
  629. # An optimizer which allows us to do NO optimization
  630. def NoOptim(self, *args, **kwargs):
  631. pass
  632. @property
  633. def breakdown_map(self):
  634. return self._breakdown_map
  635. @breakdown_map.setter
  636. def breakdown_map(self, breakdown_map):
  637. # TODO(xlwang): provide more rich feature information in breakdown_map;
  638. # and change the assertion accordingly
  639. assert isinstance(breakdown_map, dict)
  640. assert all(isinstance(k, str) for k in breakdown_map)
  641. assert sorted(breakdown_map.values()) == list(range(len(breakdown_map)))
  642. self._breakdown_map = breakdown_map