rnn_cell.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981
  1. ## @package rnn_cell
  2. # Module caffe2.python.rnn_cell
  3. import functools
  4. import inspect
  5. import logging
  6. import numpy as np
  7. import random
  8. from future.utils import viewkeys
  9. from caffe2.proto import caffe2_pb2
  10. from caffe2.python.attention import (
  11. apply_dot_attention,
  12. apply_recurrent_attention,
  13. apply_regular_attention,
  14. apply_soft_coverage_attention,
  15. AttentionType,
  16. )
  17. from caffe2.python import core, recurrent, workspace, brew, scope, utils
  18. from caffe2.python.modeling.parameter_sharing import ParameterSharing
  19. from caffe2.python.modeling.parameter_info import ParameterTags
  20. from caffe2.python.modeling.initializers import Initializer
  21. from caffe2.python.model_helper import ModelHelper
  22. def _RectifyName(blob_reference_or_name):
  23. if blob_reference_or_name is None:
  24. return None
  25. if isinstance(blob_reference_or_name, str):
  26. return core.ScopedBlobReference(blob_reference_or_name)
  27. if not isinstance(blob_reference_or_name, core.BlobReference):
  28. raise Exception("Unknown blob reference type")
  29. return blob_reference_or_name
  30. def _RectifyNames(blob_references_or_names):
  31. if blob_references_or_names is None:
  32. return None
  33. return [_RectifyName(i) for i in blob_references_or_names]
  34. class RNNCell(object):
  35. '''
  36. Base class for writing recurrent / stateful operations.
  37. One needs to implement 2 methods: apply_override
  38. and get_state_names_override.
  39. As a result base class will provice apply_over_sequence method, which
  40. allows you to apply recurrent operations over a sequence of any length.
  41. As optional you could add input and output preparation steps by overriding
  42. corresponding methods.
  43. '''
  44. def __init__(self, name=None, forward_only=False, initializer=None):
  45. self.name = name
  46. self.recompute_blobs = []
  47. self.forward_only = forward_only
  48. self._initializer = initializer
  49. @property
  50. def initializer(self):
  51. return self._initializer
  52. @initializer.setter
  53. def initializer(self, value):
  54. self._initializer = value
  55. def scope(self, name):
  56. return self.name + '/' + name if self.name is not None else name
  57. def apply_over_sequence(
  58. self,
  59. model,
  60. inputs,
  61. seq_lengths=None,
  62. initial_states=None,
  63. outputs_with_grads=None,
  64. ):
  65. if initial_states is None:
  66. with scope.NameScope(self.name):
  67. if self.initializer is None:
  68. raise Exception("Either initial states "
  69. "or initializer have to be set")
  70. initial_states = self.initializer.create_states(model)
  71. preprocessed_inputs = self.prepare_input(model, inputs)
  72. step_model = ModelHelper(name=self.name, param_model=model)
  73. input_t, timestep = step_model.net.AddScopedExternalInputs(
  74. 'input_t',
  75. 'timestep',
  76. )
  77. utils.raiseIfNotEqual(
  78. len(initial_states), len(self.get_state_names()),
  79. "Number of initial state values provided doesn't match the number "
  80. "of states"
  81. )
  82. states_prev = step_model.net.AddScopedExternalInputs(*[
  83. s + '_prev' for s in self.get_state_names()
  84. ])
  85. states = self._apply(
  86. model=step_model,
  87. input_t=input_t,
  88. seq_lengths=seq_lengths,
  89. states=states_prev,
  90. timestep=timestep,
  91. )
  92. external_outputs = set(step_model.net.Proto().external_output)
  93. for state in states:
  94. if state not in external_outputs:
  95. step_model.net.AddExternalOutput(state)
  96. if outputs_with_grads is None:
  97. outputs_with_grads = [self.get_output_state_index() * 2]
  98. # states_for_all_steps consists of combination of
  99. # states gather for all steps and final states. It looks like this:
  100. # (state_1_all, state_1_final, state_2_all, state_2_final, ...)
  101. states_for_all_steps = recurrent.recurrent_net(
  102. net=model.net,
  103. cell_net=step_model.net,
  104. inputs=[(input_t, preprocessed_inputs)],
  105. initial_cell_inputs=list(zip(states_prev, initial_states)),
  106. links=dict(zip(states_prev, states)),
  107. timestep=timestep,
  108. scope=self.name,
  109. forward_only=self.forward_only,
  110. outputs_with_grads=outputs_with_grads,
  111. recompute_blobs_on_backward=self.recompute_blobs,
  112. )
  113. output = self._prepare_output_sequence(
  114. model,
  115. states_for_all_steps,
  116. )
  117. return output, states_for_all_steps
  118. def apply(self, model, input_t, seq_lengths, states, timestep):
  119. input_t = self.prepare_input(model, input_t)
  120. states = self._apply(
  121. model, input_t, seq_lengths, states, timestep)
  122. output = self._prepare_output(model, states)
  123. return output, states
  124. def _apply(
  125. self,
  126. model, input_t, seq_lengths, states, timestep, extra_inputs=None
  127. ):
  128. '''
  129. This method uses apply_override provided by a custom cell.
  130. On the top it takes care of applying self.scope() to all the outputs.
  131. While all the inputs stay within the scope this function was called
  132. from.
  133. '''
  134. args = self._rectify_apply_inputs(
  135. input_t, seq_lengths, states, timestep, extra_inputs)
  136. with core.NameScope(self.name):
  137. return self.apply_override(model, *args)
  138. def _rectify_apply_inputs(
  139. self, input_t, seq_lengths, states, timestep, extra_inputs):
  140. '''
  141. Before applying a scope we make sure that all external blob names
  142. are converted to blob reference. So further scoping doesn't affect them
  143. '''
  144. input_t, seq_lengths, timestep = _RectifyNames(
  145. [input_t, seq_lengths, timestep])
  146. states = _RectifyNames(states)
  147. if extra_inputs:
  148. extra_input_names, extra_input_sizes = zip(*extra_inputs)
  149. extra_inputs = _RectifyNames(extra_input_names)
  150. extra_inputs = zip(extra_input_names, extra_input_sizes)
  151. arg_names = inspect.getargspec(self.apply_override).args
  152. rectified = [input_t, seq_lengths, states, timestep]
  153. if 'extra_inputs' in arg_names:
  154. rectified.append(extra_inputs)
  155. return rectified
  156. def apply_override(
  157. self,
  158. model, input_t, seq_lengths, timestep, extra_inputs=None,
  159. ):
  160. '''
  161. A single step of a recurrent network to be implemented by each custom
  162. RNNCell.
  163. model: ModelHelper object new operators would be added to
  164. input_t: singlse input with shape (1, batch_size, input_dim)
  165. seq_lengths: blob containing sequence lengths which would be passed to
  166. LSTMUnit operator
  167. states: previous recurrent states
  168. timestep: current recurrent iteration. Could be used together with
  169. seq_lengths in order to determine, if some shorter sequences
  170. in the batch have already ended.
  171. extra_inputs: list of tuples (input, dim). specifies additional input
  172. which is not subject to prepare_input(). (useful when a cell is a
  173. component of a larger recurrent structure, e.g., attention)
  174. '''
  175. raise NotImplementedError('Abstract method')
  176. def prepare_input(self, model, input_blob):
  177. '''
  178. If some operations in _apply method depend only on the input,
  179. not on recurrent states, they could be computed in advance.
  180. model: ModelHelper object new operators would be added to
  181. input_blob: either the whole input sequence with shape
  182. (sequence_length, batch_size, input_dim) or a single input with shape
  183. (1, batch_size, input_dim).
  184. '''
  185. return input_blob
  186. def get_output_state_index(self):
  187. '''
  188. Return index into state list of the "primary" step-wise output.
  189. '''
  190. return 0
  191. def get_state_names(self):
  192. '''
  193. Returns recurrent state names with self.name scoping applied
  194. '''
  195. return [self.scope(name) for name in self.get_state_names_override()]
  196. def get_state_names_override(self):
  197. '''
  198. Override this function in your custom cell.
  199. It should return the names of the recurrent states.
  200. It's required by apply_over_sequence method in order to allocate
  201. recurrent states for all steps with meaningful names.
  202. '''
  203. raise NotImplementedError('Abstract method')
  204. def get_output_dim(self):
  205. '''
  206. Specifies the dimension (number of units) of stepwise output.
  207. '''
  208. raise NotImplementedError('Abstract method')
  209. def _prepare_output(self, model, states):
  210. '''
  211. Allows arbitrary post-processing of primary output.
  212. '''
  213. return states[self.get_output_state_index()]
  214. def _prepare_output_sequence(self, model, state_outputs):
  215. '''
  216. Allows arbitrary post-processing of primary sequence output.
  217. (Note that state_outputs alternates between full-sequence and final
  218. output for each state, thus the index multiplier 2.)
  219. '''
  220. output_sequence_index = 2 * self.get_output_state_index()
  221. return state_outputs[output_sequence_index]
  222. class LSTMInitializer(object):
  223. def __init__(self, hidden_size):
  224. self.hidden_size = hidden_size
  225. def create_states(self, model):
  226. return [
  227. model.create_param(
  228. param_name='initial_hidden_state',
  229. initializer=Initializer(operator_name='ConstantFill',
  230. value=0.0),
  231. shape=[self.hidden_size],
  232. ),
  233. model.create_param(
  234. param_name='initial_cell_state',
  235. initializer=Initializer(operator_name='ConstantFill',
  236. value=0.0),
  237. shape=[self.hidden_size],
  238. )
  239. ]
  240. # based on https://pytorch.org/docs/master/nn.html#torch.nn.RNNCell
  241. class BasicRNNCell(RNNCell):
  242. def __init__(
  243. self,
  244. input_size,
  245. hidden_size,
  246. forget_bias,
  247. memory_optimization,
  248. drop_states=False,
  249. initializer=None,
  250. activation=None,
  251. **kwargs
  252. ):
  253. super(BasicRNNCell, self).__init__(**kwargs)
  254. self.drop_states = drop_states
  255. self.input_size = input_size
  256. self.hidden_size = hidden_size
  257. self.activation = activation
  258. if self.activation not in ['relu', 'tanh']:
  259. raise RuntimeError(
  260. 'BasicRNNCell with unknown activation function (%s)'
  261. % self.activation)
  262. def apply_override(
  263. self,
  264. model,
  265. input_t,
  266. seq_lengths,
  267. states,
  268. timestep,
  269. extra_inputs=None,
  270. ):
  271. hidden_t_prev = states[0]
  272. gates_t = brew.fc(
  273. model,
  274. hidden_t_prev,
  275. 'gates_t',
  276. dim_in=self.hidden_size,
  277. dim_out=self.hidden_size,
  278. axis=2,
  279. )
  280. brew.sum(model, [gates_t, input_t], gates_t)
  281. if self.activation == 'tanh':
  282. hidden_t = model.net.Tanh(gates_t, 'hidden_t')
  283. elif self.activation == 'relu':
  284. hidden_t = model.net.Relu(gates_t, 'hidden_t')
  285. else:
  286. raise RuntimeError(
  287. 'BasicRNNCell with unknown activation function (%s)'
  288. % self.activation)
  289. if seq_lengths is not None:
  290. # TODO If this codepath becomes popular, it may be worth
  291. # taking a look at optimizing it - for now a simple
  292. # implementation is used to round out compatibility with
  293. # ONNX.
  294. timestep = model.net.CopyFromCPUInput(
  295. timestep, 'timestep_gpu')
  296. valid_b = model.net.GT(
  297. [seq_lengths, timestep], 'valid_b', broadcast=1)
  298. invalid_b = model.net.LE(
  299. [seq_lengths, timestep], 'invalid_b', broadcast=1)
  300. valid = model.net.Cast(valid_b, 'valid', to='float')
  301. invalid = model.net.Cast(invalid_b, 'invalid', to='float')
  302. hidden_valid = model.net.Mul(
  303. [hidden_t, valid],
  304. 'hidden_valid',
  305. broadcast=1,
  306. axis=1,
  307. )
  308. if self.drop_states:
  309. hidden_t = hidden_valid
  310. else:
  311. hidden_invalid = model.net.Mul(
  312. [hidden_t_prev, invalid],
  313. 'hidden_invalid',
  314. broadcast=1, axis=1)
  315. hidden_t = model.net.Add(
  316. [hidden_valid, hidden_invalid], hidden_t)
  317. return (hidden_t,)
  318. def prepare_input(self, model, input_blob):
  319. return brew.fc(
  320. model,
  321. input_blob,
  322. self.scope('i2h'),
  323. dim_in=self.input_size,
  324. dim_out=self.hidden_size,
  325. axis=2,
  326. )
  327. def get_state_names(self):
  328. return (self.scope('hidden_t'),)
  329. def get_output_dim(self):
  330. return self.hidden_size
  331. class LSTMCell(RNNCell):
  332. def __init__(
  333. self,
  334. input_size,
  335. hidden_size,
  336. forget_bias,
  337. memory_optimization,
  338. drop_states=False,
  339. initializer=None,
  340. **kwargs
  341. ):
  342. super(LSTMCell, self).__init__(initializer=initializer, **kwargs)
  343. self.initializer = initializer or LSTMInitializer(
  344. hidden_size=hidden_size)
  345. self.input_size = input_size
  346. self.hidden_size = hidden_size
  347. self.forget_bias = float(forget_bias)
  348. self.memory_optimization = memory_optimization
  349. self.drop_states = drop_states
  350. self.gates_size = 4 * self.hidden_size
  351. def apply_override(
  352. self,
  353. model,
  354. input_t,
  355. seq_lengths,
  356. states,
  357. timestep,
  358. extra_inputs=None,
  359. ):
  360. hidden_t_prev, cell_t_prev = states
  361. fc_input = hidden_t_prev
  362. fc_input_dim = self.hidden_size
  363. if extra_inputs is not None:
  364. extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
  365. fc_input = brew.concat(
  366. model,
  367. [hidden_t_prev] + list(extra_input_blobs),
  368. 'gates_concatenated_input_t',
  369. axis=2,
  370. )
  371. fc_input_dim += sum(extra_input_sizes)
  372. gates_t = brew.fc(
  373. model,
  374. fc_input,
  375. 'gates_t',
  376. dim_in=fc_input_dim,
  377. dim_out=self.gates_size,
  378. axis=2,
  379. )
  380. brew.sum(model, [gates_t, input_t], gates_t)
  381. if seq_lengths is not None:
  382. inputs = [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
  383. else:
  384. inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]
  385. hidden_t, cell_t = model.net.LSTMUnit(
  386. inputs,
  387. ['hidden_state', 'cell_state'],
  388. forget_bias=self.forget_bias,
  389. drop_states=self.drop_states,
  390. sequence_lengths=(seq_lengths is not None),
  391. )
  392. model.net.AddExternalOutputs(hidden_t, cell_t)
  393. if self.memory_optimization:
  394. self.recompute_blobs = [gates_t]
  395. return hidden_t, cell_t
  396. def get_input_params(self):
  397. return {
  398. 'weights': self.scope('i2h') + '_w',
  399. 'biases': self.scope('i2h') + '_b',
  400. }
  401. def get_recurrent_params(self):
  402. return {
  403. 'weights': self.scope('gates_t') + '_w',
  404. 'biases': self.scope('gates_t') + '_b',
  405. }
  406. def prepare_input(self, model, input_blob):
  407. return brew.fc(
  408. model,
  409. input_blob,
  410. self.scope('i2h'),
  411. dim_in=self.input_size,
  412. dim_out=self.gates_size,
  413. axis=2,
  414. )
  415. def get_state_names_override(self):
  416. return ['hidden_t', 'cell_t']
  417. def get_output_dim(self):
  418. return self.hidden_size
  419. class LayerNormLSTMCell(RNNCell):
  420. def __init__(
  421. self,
  422. input_size,
  423. hidden_size,
  424. forget_bias,
  425. memory_optimization,
  426. drop_states=False,
  427. initializer=None,
  428. **kwargs
  429. ):
  430. super(LayerNormLSTMCell, self).__init__(
  431. initializer=initializer, **kwargs
  432. )
  433. self.initializer = initializer or LSTMInitializer(
  434. hidden_size=hidden_size
  435. )
  436. self.input_size = input_size
  437. self.hidden_size = hidden_size
  438. self.forget_bias = float(forget_bias)
  439. self.memory_optimization = memory_optimization
  440. self.drop_states = drop_states
  441. self.gates_size = 4 * self.hidden_size
  442. def _apply(
  443. self,
  444. model,
  445. input_t,
  446. seq_lengths,
  447. states,
  448. timestep,
  449. extra_inputs=None,
  450. ):
  451. hidden_t_prev, cell_t_prev = states
  452. fc_input = hidden_t_prev
  453. fc_input_dim = self.hidden_size
  454. if extra_inputs is not None:
  455. extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
  456. fc_input = brew.concat(
  457. model,
  458. [hidden_t_prev] + list(extra_input_blobs),
  459. self.scope('gates_concatenated_input_t'),
  460. axis=2,
  461. )
  462. fc_input_dim += sum(extra_input_sizes)
  463. gates_t = brew.fc(
  464. model,
  465. fc_input,
  466. self.scope('gates_t'),
  467. dim_in=fc_input_dim,
  468. dim_out=self.gates_size,
  469. axis=2,
  470. )
  471. brew.sum(model, [gates_t, input_t], gates_t)
  472. # brew.layer_norm call is only difference from LSTMCell
  473. gates_t, _, _ = brew.layer_norm(
  474. model,
  475. self.scope('gates_t'),
  476. self.scope('gates_t_norm'),
  477. dim_in=self.gates_size,
  478. axis=-1,
  479. )
  480. hidden_t, cell_t = model.net.LSTMUnit(
  481. [
  482. hidden_t_prev,
  483. cell_t_prev,
  484. gates_t,
  485. seq_lengths,
  486. timestep,
  487. ],
  488. self.get_state_names(),
  489. forget_bias=self.forget_bias,
  490. drop_states=self.drop_states,
  491. )
  492. model.net.AddExternalOutputs(hidden_t, cell_t)
  493. if self.memory_optimization:
  494. self.recompute_blobs = [gates_t]
  495. return hidden_t, cell_t
  496. def get_input_params(self):
  497. return {
  498. 'weights': self.scope('i2h') + '_w',
  499. 'biases': self.scope('i2h') + '_b',
  500. }
  501. def prepare_input(self, model, input_blob):
  502. return brew.fc(
  503. model,
  504. input_blob,
  505. self.scope('i2h'),
  506. dim_in=self.input_size,
  507. dim_out=self.gates_size,
  508. axis=2,
  509. )
  510. def get_state_names(self):
  511. return (self.scope('hidden_t'), self.scope('cell_t'))
  512. class MILSTMCell(LSTMCell):
  513. def _apply(
  514. self,
  515. model,
  516. input_t,
  517. seq_lengths,
  518. states,
  519. timestep,
  520. extra_inputs=None,
  521. ):
  522. hidden_t_prev, cell_t_prev = states
  523. fc_input = hidden_t_prev
  524. fc_input_dim = self.hidden_size
  525. if extra_inputs is not None:
  526. extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
  527. fc_input = brew.concat(
  528. model,
  529. [hidden_t_prev] + list(extra_input_blobs),
  530. self.scope('gates_concatenated_input_t'),
  531. axis=2,
  532. )
  533. fc_input_dim += sum(extra_input_sizes)
  534. prev_t = brew.fc(
  535. model,
  536. fc_input,
  537. self.scope('prev_t'),
  538. dim_in=fc_input_dim,
  539. dim_out=self.gates_size,
  540. axis=2,
  541. )
  542. # defining initializers for MI parameters
  543. alpha = model.create_param(
  544. self.scope('alpha'),
  545. shape=[self.gates_size],
  546. initializer=Initializer('ConstantFill', value=1.0),
  547. )
  548. beta_h = model.create_param(
  549. self.scope('beta1'),
  550. shape=[self.gates_size],
  551. initializer=Initializer('ConstantFill', value=1.0),
  552. )
  553. beta_i = model.create_param(
  554. self.scope('beta2'),
  555. shape=[self.gates_size],
  556. initializer=Initializer('ConstantFill', value=1.0),
  557. )
  558. b = model.create_param(
  559. self.scope('b'),
  560. shape=[self.gates_size],
  561. initializer=Initializer('ConstantFill', value=0.0),
  562. )
  563. # alpha * input_t + beta_h
  564. # Shape: [1, batch_size, 4 * hidden_size]
  565. alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
  566. [input_t, alpha, beta_h],
  567. self.scope('alpha_by_input_t_plus_beta_h'),
  568. axis=2,
  569. )
  570. # (alpha * input_t + beta_h) * prev_t =
  571. # alpha * input_t * prev_t + beta_h * prev_t
  572. # Shape: [1, batch_size, 4 * hidden_size]
  573. alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
  574. [alpha_by_input_t_plus_beta_h, prev_t],
  575. self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
  576. )
  577. # beta_i * input_t + b
  578. # Shape: [1, batch_size, 4 * hidden_size]
  579. beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
  580. [input_t, beta_i, b],
  581. self.scope('beta_i_by_input_t_plus_b'),
  582. axis=2,
  583. )
  584. # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
  585. # Shape: [1, batch_size, 4 * hidden_size]
  586. gates_t = brew.sum(
  587. model,
  588. [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
  589. self.scope('gates_t')
  590. )
  591. hidden_t, cell_t = model.net.LSTMUnit(
  592. [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
  593. [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
  594. forget_bias=self.forget_bias,
  595. drop_states=self.drop_states,
  596. )
  597. model.net.AddExternalOutputs(
  598. cell_t,
  599. hidden_t,
  600. )
  601. if self.memory_optimization:
  602. self.recompute_blobs = [gates_t]
  603. return hidden_t, cell_t
  604. class LayerNormMILSTMCell(LSTMCell):
  605. def _apply(
  606. self,
  607. model,
  608. input_t,
  609. seq_lengths,
  610. states,
  611. timestep,
  612. extra_inputs=None,
  613. ):
  614. hidden_t_prev, cell_t_prev = states
  615. fc_input = hidden_t_prev
  616. fc_input_dim = self.hidden_size
  617. if extra_inputs is not None:
  618. extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
  619. fc_input = brew.concat(
  620. model,
  621. [hidden_t_prev] + list(extra_input_blobs),
  622. self.scope('gates_concatenated_input_t'),
  623. axis=2,
  624. )
  625. fc_input_dim += sum(extra_input_sizes)
  626. prev_t = brew.fc(
  627. model,
  628. fc_input,
  629. self.scope('prev_t'),
  630. dim_in=fc_input_dim,
  631. dim_out=self.gates_size,
  632. axis=2,
  633. )
  634. # defining initializers for MI parameters
  635. alpha = model.create_param(
  636. self.scope('alpha'),
  637. shape=[self.gates_size],
  638. initializer=Initializer('ConstantFill', value=1.0),
  639. )
  640. beta_h = model.create_param(
  641. self.scope('beta1'),
  642. shape=[self.gates_size],
  643. initializer=Initializer('ConstantFill', value=1.0),
  644. )
  645. beta_i = model.create_param(
  646. self.scope('beta2'),
  647. shape=[self.gates_size],
  648. initializer=Initializer('ConstantFill', value=1.0),
  649. )
  650. b = model.create_param(
  651. self.scope('b'),
  652. shape=[self.gates_size],
  653. initializer=Initializer('ConstantFill', value=0.0),
  654. )
  655. # alpha * input_t + beta_h
  656. # Shape: [1, batch_size, 4 * hidden_size]
  657. alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
  658. [input_t, alpha, beta_h],
  659. self.scope('alpha_by_input_t_plus_beta_h'),
  660. axis=2,
  661. )
  662. # (alpha * input_t + beta_h) * prev_t =
  663. # alpha * input_t * prev_t + beta_h * prev_t
  664. # Shape: [1, batch_size, 4 * hidden_size]
  665. alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
  666. [alpha_by_input_t_plus_beta_h, prev_t],
  667. self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
  668. )
  669. # beta_i * input_t + b
  670. # Shape: [1, batch_size, 4 * hidden_size]
  671. beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
  672. [input_t, beta_i, b],
  673. self.scope('beta_i_by_input_t_plus_b'),
  674. axis=2,
  675. )
  676. # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
  677. # Shape: [1, batch_size, 4 * hidden_size]
  678. gates_t = brew.sum(
  679. model,
  680. [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
  681. self.scope('gates_t')
  682. )
  683. # brew.layer_norm call is only difference from MILSTMCell._apply
  684. gates_t, _, _ = brew.layer_norm(
  685. model,
  686. self.scope('gates_t'),
  687. self.scope('gates_t_norm'),
  688. dim_in=self.gates_size,
  689. axis=-1,
  690. )
  691. hidden_t, cell_t = model.net.LSTMUnit(
  692. [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
  693. [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
  694. forget_bias=self.forget_bias,
  695. drop_states=self.drop_states,
  696. )
  697. model.net.AddExternalOutputs(
  698. cell_t,
  699. hidden_t,
  700. )
  701. if self.memory_optimization:
  702. self.recompute_blobs = [gates_t]
  703. return hidden_t, cell_t
  704. class DropoutCell(RNNCell):
  705. '''
  706. Wraps arbitrary RNNCell, applying dropout to its output (but not to the
  707. recurrent connection for the corresponding state).
  708. '''
  709. def __init__(
  710. self,
  711. internal_cell,
  712. dropout_ratio=None,
  713. use_cudnn=False,
  714. **kwargs
  715. ):
  716. self.internal_cell = internal_cell
  717. self.dropout_ratio = dropout_ratio
  718. assert 'is_test' in kwargs, "Argument 'is_test' is required"
  719. self.is_test = kwargs.pop('is_test')
  720. self.use_cudnn = use_cudnn
  721. super(DropoutCell, self).__init__(**kwargs)
  722. self.prepare_input = internal_cell.prepare_input
  723. self.get_output_state_index = internal_cell.get_output_state_index
  724. self.get_state_names = internal_cell.get_state_names
  725. self.get_output_dim = internal_cell.get_output_dim
  726. self.mask = 0
  727. def _apply(
  728. self,
  729. model,
  730. input_t,
  731. seq_lengths,
  732. states,
  733. timestep,
  734. extra_inputs=None,
  735. ):
  736. return self.internal_cell._apply(
  737. model,
  738. input_t,
  739. seq_lengths,
  740. states,
  741. timestep,
  742. extra_inputs,
  743. )
  744. def _prepare_output(self, model, states):
  745. output = self.internal_cell._prepare_output(
  746. model,
  747. states,
  748. )
  749. if self.dropout_ratio is not None:
  750. output = self._apply_dropout(model, output)
  751. return output
  752. def _prepare_output_sequence(self, model, state_outputs):
  753. output = self.internal_cell._prepare_output_sequence(
  754. model,
  755. state_outputs,
  756. )
  757. if self.dropout_ratio is not None:
  758. output = self._apply_dropout(model, output)
  759. return output
  760. def _apply_dropout(self, model, output):
  761. if self.dropout_ratio and not self.forward_only:
  762. with core.NameScope(self.name or ''):
  763. output = brew.dropout(
  764. model,
  765. output,
  766. str(output) + '_with_dropout_mask{}'.format(self.mask),
  767. ratio=float(self.dropout_ratio),
  768. is_test=self.is_test,
  769. use_cudnn=self.use_cudnn,
  770. )
  771. self.mask += 1
  772. return output
  773. class MultiRNNCellInitializer(object):
  774. def __init__(self, cells):
  775. self.cells = cells
  776. def create_states(self, model):
  777. states = []
  778. for i, cell in enumerate(self.cells):
  779. if cell.initializer is None:
  780. raise Exception("Either initial states "
  781. "or initializer have to be set")
  782. with core.NameScope("layer_{}".format(i)),\
  783. core.NameScope(cell.name):
  784. states.extend(cell.initializer.create_states(model))
  785. return states
class MultiRNNCell(RNNCell):
    '''
    Multilayer RNN via the composition of RNNCell instances.

    It is the responsibility of calling code to ensure the compatibility
    of the successive layers in terms of input/output dimensionality, etc.,
    and to ensure that their blobs do not have name conflicts, typically by
    creating the cells with names that specify layer number.

    Assumes first state (recurrent output) for each layer should be the input
    to the next layer.

    The states of all layers are concatenated into one flat list in layer
    order; each layer's state names are prefixed "<name>/layer_<i>/".
    '''

    def __init__(self, cells, residual_output_layers=None, **kwargs):
        '''
        cells: list of RNNCell instances, from input to output side.
        name: string designating network component (for scoping)
        residual_output_layers: list of indices of layers whose input will
            be added elementwise to their output. (It is the
            responsibility of the client code to ensure shape compatibility.)
            Note that layer 0 (zero) cannot have residual output because of the
            timing of prepare_input().
        forward_only: used to construct inference-only network.
        '''
        super(MultiRNNCell, self).__init__(**kwargs)
        self.cells = cells

        if residual_output_layers is None:
            self.residual_output_layers = []
        else:
            self.residual_output_layers = residual_output_layers

        # Index of each layer's output state within the flat state list.
        output_index_per_layer = []
        base_index = 0
        for cell in self.cells:
            output_index_per_layer.append(
                base_index + cell.get_output_state_index(),
            )
            base_index += len(cell.get_state_names())

        # Layers whose outputs are summed into the final output. Layer i is
        # connected when layer i + 1 is residual; a non-residual layer resets
        # the list, so only the trailing run of consecutive residual layers
        # (ending at the last layer, appended below) contributes.
        self.output_connected_layers = []
        self.output_indices = []
        for i in range(len(self.cells) - 1):
            if (i + 1) in self.residual_output_layers:
                self.output_connected_layers.append(i)
                self.output_indices.append(output_index_per_layer[i])
            else:
                self.output_connected_layers = []
                self.output_indices = []
        self.output_connected_layers.append(len(self.cells) - 1)
        self.output_indices.append(output_index_per_layer[-1])

        self.state_names = []
        for i, cell in enumerate(self.cells):
            self.state_names.extend(
                map(self.layer_scoper(i), cell.get_state_names())
            )

        self.initializer = MultiRNNCellInitializer(cells)

    def layer_scoper(self, layer_id):
        '''Returns a function mapping a state name to "<name>/layer_<id>/<state>".'''
        # NOTE(review): when self.name is None, str.format renders it as
        # "None", giving names like "None/layer_0/hidden_t".
        def helper(name):
            return "{}/layer_{}/{}".format(self.name, layer_id, name)
        return helper

    def prepare_input(self, model, input_blob):
        # Only the first layer consumes the external input sequence, so only
        # its input preparation is done once up front, outside the step net.
        input_blob = _RectifyName(input_blob)
        with core.NameScope(self.name or ''):
            return self.cells[0].prepare_input(model, input_blob)

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        Because below we will do scoping across layers, we need
        to make sure that string blob names are converted to BlobReference
        objects.

        Runs every layer for one timestep: `states` is the flat list of all
        layers' states and the flat list of all next states is returned.
        extra_inputs are passed to layer 0 only.
        '''

        input_t, seq_lengths, states, timestep, extra_inputs = \
            self._rectify_apply_inputs(
                input_t, seq_lengths, states, timestep, extra_inputs)

        states_per_layer = [len(cell.get_state_names()) for cell in self.cells]
        assert len(states) == sum(states_per_layer)

        next_states = []
        states_index = 0

        layer_input = input_t
        for i, layer_cell in enumerate(self.cells):
            # # If cells don't have different names we still
            # take care of scoping
            with core.NameScope(self.name), core.NameScope("layer_{}".format(i)):
                num_states = states_per_layer[i]
                layer_states = states[states_index:(states_index + num_states)]
                states_index += num_states

                # Layer 0's input was already prepared by prepare_input();
                # deeper layers prepare the previous layer's output here.
                if i > 0:
                    prepared_input = layer_cell.prepare_input(
                        model, layer_input)
                else:
                    prepared_input = layer_input

                layer_next_states = layer_cell._apply(
                    model,
                    prepared_input,
                    seq_lengths,
                    layer_states,
                    timestep,
                    extra_inputs=(None if i > 0 else extra_inputs),
                )
                # Since we're using here non-public method _apply,
                # instead of apply, we have to manually extract output
                # from states
                if i != len(self.cells) - 1:
                    layer_output = layer_cell._prepare_output(
                        model,
                        layer_next_states,
                    )
                    # Residual connection: next layer's input is this
                    # layer's output plus this layer's input.
                    if i > 0 and i in self.residual_output_layers:
                        layer_input = brew.sum(
                            model,
                            [layer_output, layer_input],
                            self.scope('residual_output_{}'.format(i)),
                        )
                    else:
                        layer_input = layer_output

            next_states.extend(layer_next_states)
        return next_states

    def get_state_names(self):
        return self.state_names

    def get_output_state_index(self):
        # Position of the last layer's output state in the flat state list.
        index = 0
        for cell in self.cells[:-1]:
            index += len(cell.get_state_names())
        index += self.cells[-1].get_output_state_index()
        return index

    def _prepare_output(self, model, states):
        '''
        Combines the outputs of the layers in output_connected_layers
        (summed when there is more than one) from one step's flat states.
        '''
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output(
                    model,
                    layer_states
                )
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output'),
            )
        else:
            output = connected_outputs[0]
        return output

    def _prepare_output_sequence(self, model, states):
        '''
        Sequence counterpart of _prepare_output. Here each state contributes
        two entries to `states` (hence the factor 2 below).
        '''
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = 2 * len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output_sequence(
                    model,
                    layer_states
                )
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output_sequence'),
            )
        else:
            output = connected_outputs[0]
        return output
class AttentionCell(RNNCell):
    '''
    Wraps a decoder RNNCell and adds an attention mechanism over
    `encoder_outputs`.

    State layout: the decoder cell's own states, then the attention-weighted
    encoder context, then (AttentionType.SoftCoverage only) the coverage
    vector. The decoder's output state is exposed under the name
    '<name>/hidden_t_external'.
    '''

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_cell,
        decoder_state_dim,
        attention_type,
        weighted_encoder_outputs,
        attention_memory_optimization,
        **kwargs
    ):
        super(AttentionCell, self).__init__(**kwargs)
        self.encoder_output_dim = encoder_output_dim
        self.encoder_outputs = encoder_outputs
        self.encoder_lengths = encoder_lengths
        self.decoder_cell = decoder_cell
        self.decoder_state_dim = decoder_state_dim
        # May be None; prepare_input() then computes it (except for Dot).
        self.weighted_encoder_outputs = weighted_encoder_outputs
        # Lazily created in prepare_input().
        self.encoder_outputs_transposed = None
        assert attention_type in [
            AttentionType.Regular,
            AttentionType.Recurrent,
            AttentionType.Dot,
            AttentionType.SoftCoverage,
        ]
        self.attention_type = attention_type
        self.attention_memory_optimization = attention_memory_optimization

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        Builds one decoding step: runs the decoder cell (feeding it the
        previous attention context as an extra input), then computes the new
        attention context with the configured attention type.

        Returns the new states in the same layout as `states`; all of them
        are registered as external outputs of model.net.
        '''
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_prev_states = states[:-2]
            attention_weighted_encoder_context_t_prev = states[-2]
            coverage_t_prev = states[-1]
        else:
            decoder_prev_states = states[:-1]
            attention_weighted_encoder_context_t_prev = states[-1]

        # This cell supplies the decoder's extra input itself (the previous
        # context), so callers must not pass any.
        assert extra_inputs is None

        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )

        # Kept on the instance: _prepare_output() reads it later, so it is
        # only valid after _apply() has been called.
        self.hidden_t_intermediate = self.decoder_cell._prepare_output(
            model,
            decoder_states,
        )

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Regular:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Dot:
            # Dot attention does not use weighted_encoder_outputs.
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_dot_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.SoftCoverage:
            # Requires build_initial_coverage() to have set coverage_weights.
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
                coverage_t,
            ) = apply_soft_coverage_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
                coverage_t_prev=coverage_t_prev,
                coverage_weights=self.coverage_weights,
            )
        else:
            raise Exception('Attention type {} not implemented'.format(
                self.attention_type
            ))

        # Attention intermediates are added to the blobs recomputed on the
        # backward pass instead of being stored.
        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        output = list(decoder_states) + [attention_weighted_encoder_context_t]
        if self.attention_type == AttentionType.SoftCoverage:
            output.append(coverage_t)

        # Expose the decoder's output state under the externally visible name
        # used in get_state_names().
        output[self.decoder_cell.get_output_state_index()] = model.Copy(
            output[self.decoder_cell.get_output_state_index()],
            self.scope('hidden_t_external'),
        )
        model.net.AddExternalOutputs(*output)

        return output

    def get_attention_weights(self):
        # [batch_size, encoder_length, 1]
        # Only available after _apply() has run.
        return self.attention_weights_3d

    def prepare_input(self, model, input_blob):
        # One-time (non-recurrent) precomputation on the encoder side.
        if self.encoder_outputs_transposed is None:
            self.encoder_outputs_transposed = brew.transpose(
                model,
                self.encoder_outputs,
                self.scope('encoder_outputs_transposed'),
                axes=[1, 2, 0],
            )
        if (
            self.weighted_encoder_outputs is None and
            self.attention_type != AttentionType.Dot
        ):
            self.weighted_encoder_outputs = brew.fc(
                model,
                self.encoder_outputs,
                self.scope('weighted_encoder_outputs'),
                dim_in=self.encoder_output_dim,
                dim_out=self.encoder_output_dim,
                axis=2,
            )

        return self.decoder_cell.prepare_input(model, input_blob)

    def build_initial_coverage(self, model):
        """
        initial_coverage is always zeros of shape [encoder_length],
        which shape must be determined programmatically during network
        computation.

        This method also sets self.coverage_weights, a separate transform
        of encoder_outputs which is used to determine coverage contribution
        to attention.
        """
        assert self.attention_type == AttentionType.SoftCoverage

        # [encoder_length, batch_size, encoder_output_dim]
        self.coverage_weights = brew.fc(
            model,
            self.encoder_outputs,
            self.scope('coverage_weights'),
            dim_in=self.encoder_output_dim,
            dim_out=self.encoder_output_dim,
            axis=2,
        )

        # First dimension of encoder_outputs' runtime shape.
        encoder_length = model.net.Slice(
            model.net.Shape(self.encoder_outputs),
            starts=[0],
            ends=[1],
        )
        # When building under a GPU device scope, the length is copied to
        # CPU before ConstantFill uses it as a shape (input_as_shape=1).
        if (
            scope.CurrentDeviceScope() is not None and
            core.IsGPUDeviceType(scope.CurrentDeviceScope().device_type)
        ):
            encoder_length = model.net.CopyGPUToCPU(
                encoder_length,
                'encoder_length_cpu',
            )
        # total attention weight applied across decoding steps
        # shape: [encoder_length]
        initial_coverage = model.net.ConstantFill(
            encoder_length,
            self.scope('initial_coverage'),
            value=0.0,
            input_as_shape=1,
        )
        return initial_coverage

    def get_state_names(self):
        state_names = list(self.decoder_cell.get_state_names())
        state_names[self.get_output_state_index()] = self.scope(
            'hidden_t_external',
        )
        state_names.append(self.scope('attention_weighted_encoder_context_t'))
        if self.attention_type == AttentionType.SoftCoverage:
            state_names.append(self.scope('coverage_t'))
        return state_names

    def get_output_dim(self):
        # Output is the decoder hidden state concatenated with the context.
        return self.decoder_state_dim + self.encoder_output_dim

    def get_output_state_index(self):
        return self.decoder_cell.get_output_state_index()

    def _prepare_output(self, model, states):
        # Uses self.hidden_t_intermediate, set by the most recent _apply().
        if self.attention_type == AttentionType.SoftCoverage:
            attention_context = states[-2]
        else:
            attention_context = states[-1]

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [self.hidden_t_intermediate, attention_context],
                'states_and_context_combination',
                axis=2,
            )

        return output

    def _prepare_output_sequence(self, model, state_outputs):
        # state_outputs holds two entries per state, so the trailing
        # non-decoder states occupy 2 (or 4 with coverage) entries.
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_state_outputs = state_outputs[:-4]
        else:
            decoder_state_outputs = state_outputs[:-2]

        decoder_output = self.decoder_cell._prepare_output_sequence(
            model,
            decoder_state_outputs,
        )

        # First of the two entries belonging to the attention context state.
        if self.attention_type == AttentionType.SoftCoverage:
            attention_context_index = 2 * (len(self.get_state_names()) - 2)
        else:
            attention_context_index = 2 * (len(self.get_state_names()) - 1)

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [
                    decoder_output,
                    state_outputs[attention_context_index],
                ],
                'states_and_context_combination',
                axis=2,
            )
        return output
  1213. class LSTMWithAttentionCell(AttentionCell):
  1214. def __init__(
  1215. self,
  1216. encoder_output_dim,
  1217. encoder_outputs,
  1218. encoder_lengths,
  1219. decoder_input_dim,
  1220. decoder_state_dim,
  1221. name,
  1222. attention_type,
  1223. weighted_encoder_outputs,
  1224. forget_bias,
  1225. lstm_memory_optimization,
  1226. attention_memory_optimization,
  1227. forward_only=False,
  1228. ):
  1229. decoder_cell = LSTMCell(
  1230. input_size=decoder_input_dim,
  1231. hidden_size=decoder_state_dim,
  1232. forget_bias=forget_bias,
  1233. memory_optimization=lstm_memory_optimization,
  1234. name='{}/decoder'.format(name),
  1235. forward_only=False,
  1236. drop_states=False,
  1237. )
  1238. super(LSTMWithAttentionCell, self).__init__(
  1239. encoder_output_dim=encoder_output_dim,
  1240. encoder_outputs=encoder_outputs,
  1241. encoder_lengths=encoder_lengths,
  1242. decoder_cell=decoder_cell,
  1243. decoder_state_dim=decoder_state_dim,
  1244. name=name,
  1245. attention_type=attention_type,
  1246. weighted_encoder_outputs=weighted_encoder_outputs,
  1247. attention_memory_optimization=attention_memory_optimization,
  1248. forward_only=forward_only,
  1249. )
  1250. class MILSTMWithAttentionCell(AttentionCell):
  1251. def __init__(
  1252. self,
  1253. encoder_output_dim,
  1254. encoder_outputs,
  1255. decoder_input_dim,
  1256. decoder_state_dim,
  1257. name,
  1258. attention_type,
  1259. weighted_encoder_outputs,
  1260. forget_bias,
  1261. lstm_memory_optimization,
  1262. attention_memory_optimization,
  1263. forward_only=False,
  1264. ):
  1265. decoder_cell = MILSTMCell(
  1266. input_size=decoder_input_dim,
  1267. hidden_size=decoder_state_dim,
  1268. forget_bias=forget_bias,
  1269. memory_optimization=lstm_memory_optimization,
  1270. name='{}/decoder'.format(name),
  1271. forward_only=False,
  1272. drop_states=False,
  1273. )
  1274. super(MILSTMWithAttentionCell, self).__init__(
  1275. encoder_output_dim=encoder_output_dim,
  1276. encoder_outputs=encoder_outputs,
  1277. decoder_cell=decoder_cell,
  1278. decoder_state_dim=decoder_state_dim,
  1279. name=name,
  1280. attention_type=attention_type,
  1281. weighted_encoder_outputs=weighted_encoder_outputs,
  1282. attention_memory_optimization=attention_memory_optimization,
  1283. forward_only=forward_only,
  1284. )
  1285. def _LSTM(
  1286. cell_class,
  1287. model,
  1288. input_blob,
  1289. seq_lengths,
  1290. initial_states,
  1291. dim_in,
  1292. dim_out,
  1293. scope=None,
  1294. outputs_with_grads=(0,),
  1295. return_params=False,
  1296. memory_optimization=False,
  1297. forget_bias=0.0,
  1298. forward_only=False,
  1299. drop_states=False,
  1300. return_last_layer_only=True,
  1301. static_rnn_unroll_size=None,
  1302. **cell_kwargs
  1303. ):
  1304. '''
  1305. Adds a standard LSTM recurrent network operator to a model.
  1306. cell_class: LSTMCell or compatible subclass
  1307. model: ModelHelper object new operators would be added to
  1308. input_blob: the input sequence in a format T x N x D
  1309. where T is sequence size, N - batch size and D - input dimension
  1310. seq_lengths: blob containing sequence lengths which would be passed to
  1311. LSTMUnit operator
  1312. initial_states: a list of (2 * num_layers) blobs representing the initial
  1313. hidden and cell states of each layer. If this argument is None,
  1314. these states will be added to the model as network parameters.
  1315. dim_in: input dimension
  1316. dim_out: number of units per LSTM layer
  1317. (use int for single-layer LSTM, list of ints for multi-layer)
  1318. outputs_with_grads : position indices of output blobs for LAST LAYER which
  1319. will receive external error gradient during backpropagation.
  1320. These outputs are: (h_all, h_last, c_all, c_last)
  1321. return_params: if True, will return a dictionary of parameters of the LSTM
  1322. memory_optimization: if enabled, the LSTM step is recomputed on backward
  1323. step so that we don't need to store forward activations for each
  1324. timestep. Saves memory with cost of computation.
  1325. forget_bias: forget gate bias (default 0.0)
  1326. forward_only: whether to create a backward pass
  1327. drop_states: drop invalid states, passed through to LSTMUnit operator
  1328. return_last_layer_only: only return outputs from final layer
  1329. (so that length of results does depend on number of layers)
  1330. static_rnn_unroll_size: if not None, we will use static RNN which is
  1331. unrolled into Caffe2 graph. The size of the unroll is the value of
  1332. this parameter.
  1333. '''
  1334. if type(dim_out) is not list and type(dim_out) is not tuple:
  1335. dim_out = [dim_out]
  1336. num_layers = len(dim_out)
  1337. cells = []
  1338. for i in range(num_layers):
  1339. cell = cell_class(
  1340. input_size=(dim_in if i == 0 else dim_out[i - 1]),
  1341. hidden_size=dim_out[i],
  1342. forget_bias=forget_bias,
  1343. memory_optimization=memory_optimization,
  1344. name=scope if num_layers == 1 else None,
  1345. forward_only=forward_only,
  1346. drop_states=drop_states,
  1347. **cell_kwargs
  1348. )
  1349. cells.append(cell)
  1350. cell = MultiRNNCell(
  1351. cells,
  1352. name=scope,
  1353. forward_only=forward_only,
  1354. ) if num_layers > 1 else cells[0]
  1355. cell = (
  1356. cell if static_rnn_unroll_size is None
  1357. else UnrolledCell(cell, static_rnn_unroll_size))
  1358. # outputs_with_grads argument indexes into final layer
  1359. outputs_with_grads = [4 * (num_layers - 1) + i for i in outputs_with_grads]
  1360. _, result = cell.apply_over_sequence(
  1361. model=model,
  1362. inputs=input_blob,
  1363. seq_lengths=seq_lengths,
  1364. initial_states=initial_states,
  1365. outputs_with_grads=outputs_with_grads,
  1366. )
  1367. if return_last_layer_only:
  1368. result = result[4 * (num_layers - 1):]
  1369. if return_params:
  1370. result = list(result) + [{
  1371. 'input': cell.get_input_params(),
  1372. 'recurrent': cell.get_recurrent_params(),
  1373. }]
  1374. return tuple(result)
# Public recurrent-network builders. Each one binds _LSTM's first parameter
# (cell_class) positionally to a concrete cell type, so callers pass the
# remaining _LSTM arguments starting with `model`.
LSTM = functools.partial(_LSTM, LSTMCell)
BasicRNN = functools.partial(_LSTM, BasicRNNCell)
MILSTM = functools.partial(_LSTM, MILSTMCell)
LayerNormLSTM = functools.partial(_LSTM, LayerNormLSTMCell)
LayerNormMILSTM = functools.partial(_LSTM, LayerNormMILSTMCell)
  1380. class UnrolledCell(RNNCell):
  1381. def __init__(self, cell, T):
  1382. self.T = T
  1383. self.cell = cell
  1384. def apply_over_sequence(
  1385. self,
  1386. model,
  1387. inputs,
  1388. seq_lengths,
  1389. initial_states,
  1390. outputs_with_grads=None,
  1391. ):
  1392. inputs = self.cell.prepare_input(model, inputs)
  1393. # Now they are blob references - outputs of splitting the input sequence
  1394. split_inputs = model.net.Split(
  1395. inputs,
  1396. [str(inputs) + "_timestep_{}".format(i)
  1397. for i in range(self.T)],
  1398. axis=0)
  1399. if self.T == 1:
  1400. split_inputs = [split_inputs]
  1401. states = initial_states
  1402. all_states = []
  1403. for t in range(0, self.T):
  1404. scope_name = "timestep_{}".format(t)
  1405. # Parameters of all timesteps are shared
  1406. with ParameterSharing({scope_name: ''}),\
  1407. scope.NameScope(scope_name):
  1408. timestep = model.param_init_net.ConstantFill(
  1409. [], "timestep", value=t, shape=[1],
  1410. dtype=core.DataType.INT32,
  1411. device_option=core.DeviceOption(caffe2_pb2.CPU))
  1412. states = self.cell._apply(
  1413. model=model,
  1414. input_t=split_inputs[t],
  1415. seq_lengths=seq_lengths,
  1416. states=states,
  1417. timestep=timestep,
  1418. )
  1419. all_states.append(states)
  1420. all_states = zip(*all_states)
  1421. all_states = [
  1422. model.net.Concat(
  1423. list(full_output),
  1424. [
  1425. str(full_output[0])[len("timestep_0/"):] + "_concat",
  1426. str(full_output[0])[len("timestep_0/"):] + "_concat_info"
  1427. ],
  1428. axis=0)[0]
  1429. for full_output in all_states
  1430. ]
  1431. # Interleave the state values similar to
  1432. #
  1433. # x = [1, 3, 5]
  1434. # y = [2, 4, 6]
  1435. # z = [val for pair in zip(x, y) for val in pair]
  1436. # # z is [1, 2, 3, 4, 5, 6]
  1437. #
  1438. # and returns it as outputs
  1439. outputs = tuple(
  1440. state for state_pair in zip(all_states, states) for state in state_pair
  1441. )
  1442. outputs_without_grad = set(range(len(outputs))) - set(
  1443. outputs_with_grads)
  1444. for i in outputs_without_grad:
  1445. model.net.ZeroGradient(outputs[i], [])
  1446. logging.debug("Added 0 gradients for blobs:",
  1447. [outputs[i] for i in outputs_without_grad])
  1448. final_output = self.cell._prepare_output_sequence(model, outputs)
  1449. return final_output, outputs
  1450. def GetLSTMParamNames():
  1451. weight_params = ["input_gate_w", "forget_gate_w", "output_gate_w", "cell_w"]
  1452. bias_params = ["input_gate_b", "forget_gate_b", "output_gate_b", "cell_b"]
  1453. return {'weights': weight_params, 'biases': bias_params}
  1454. def InitFromLSTMParams(lstm_pblobs, param_values):
  1455. '''
  1456. Set the parameters of LSTM based on predefined values
  1457. '''
  1458. weight_params = GetLSTMParamNames()['weights']
  1459. bias_params = GetLSTMParamNames()['biases']
  1460. for input_type in viewkeys(param_values):
  1461. weight_values = [
  1462. param_values[input_type][w].flatten()
  1463. for w in weight_params
  1464. ]
  1465. wmat = np.array([])
  1466. for w in weight_values:
  1467. wmat = np.append(wmat, w)
  1468. bias_values = [
  1469. param_values[input_type][b].flatten()
  1470. for b in bias_params
  1471. ]
  1472. bm = np.array([])
  1473. for b in bias_values:
  1474. bm = np.append(bm, b)
  1475. weights_blob = lstm_pblobs[input_type]['weights']
  1476. bias_blob = lstm_pblobs[input_type]['biases']
  1477. cur_weight = workspace.FetchBlob(weights_blob)
  1478. cur_biases = workspace.FetchBlob(bias_blob)
  1479. workspace.FeedBlob(
  1480. weights_blob,
  1481. wmat.reshape(cur_weight.shape).astype(np.float32))
  1482. workspace.FeedBlob(
  1483. bias_blob,
  1484. bm.reshape(cur_biases.shape).astype(np.float32))
def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
               scope, recurrent_params=None, input_params=None,
               num_layers=1, return_params=False):
    '''
    CuDNN version of LSTM for GPUs.
    input_blob          Blob containing the input. Will need to be available
                        when param_init_net is run, because the sequence lengths
                        and batch sizes will be inferred from the size of this
                        blob.
    initial_states      tuple of (hidden_init, cell_init) blobs
    dim_in              input dimensions
    dim_out             output/hidden dimension
    scope               namescope to apply
    recurrent_params    dict of blobs containing values for recurrent
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    input_params        dict of blobs containing values for input
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    num_layers          number of LSTM layers
    return_params       if True, returns (param_extract_net, param_mapping)
                        where param_extract_net is a net that when run, will
                        populate the blobs specified in param_mapping with the
                        current gate weights and biases (input/recurrent).
                        Useful for assigning the values back to non-cuDNN
                        LSTM.

    Returns (output, hidden_output, cell_output), plus the
    (param_extract_net, param_extract_mapping) pair when return_params is set.
    param_extract_mapping is indexed as [input_type][param_name][layer].
    NOTE(review): a user-supplied param dict provides one blob per param
    name, and that same blob is used for every layer.
    '''
    with core.NameScope(scope):
        weight_params = GetLSTMParamNames()['weights']
        bias_params = GetLSTMParamNames()['biases']

        # Per-gate parameter sizes (flattened).
        input_weight_size = dim_out * dim_in
        upper_layer_input_weight_size = dim_out * dim_out
        recurrent_weight_size = dim_out * dim_out
        input_bias_size = dim_out
        recurrent_bias_size = dim_out

        def init(layer, pname, input_type):
            # Random (UniformFill) default for one gate parameter of one layer.
            input_weight_size_for_layer = input_weight_size if layer == 0 else \
                upper_layer_input_weight_size
            if pname in weight_params:
                sz = input_weight_size_for_layer if input_type == 'input' \
                    else recurrent_weight_size
            elif pname in bias_params:
                sz = input_bias_size if input_type == 'input' \
                    else recurrent_bias_size
            else:
                # NOTE(review): under `python -O` this assert is stripped and
                # `sz` would be unbound below.
                assert False, "unknown parameter type {}".format(pname)
            return model.param_init_net.UniformFill(
                [],
                "lstm_init_{}_{}_{}".format(input_type, pname, layer),
                shape=[sz])

        first_layer_sz = input_weight_size + recurrent_weight_size + \
            input_bias_size + recurrent_bias_size
        upper_layer_sz = upper_layer_input_weight_size + \
            recurrent_weight_size + input_bias_size + \
            recurrent_bias_size
        # Multiply by 4 since we have 4 gates per LSTM unit
        total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)

        # All layers' parameters live in this single flat blob.
        weights = model.create_param(
            'lstm_weight',
            shape=[total_sz],
            initializer=Initializer('UniformFill'),
            tags=ParameterTags.WEIGHT,
        )

        lstm_args = {
            'hidden_size': dim_out,
            'rnn_mode': 'lstm',
            'bidirectional': 0,  # TODO
            'dropout': 1.0,  # TODO
            'input_mode': 'linear',  # TODO
            'num_layers': num_layers,
            'engine': 'CUDNN'
        }

        param_extract_net = core.Net("lstm_param_extractor")
        param_extract_net.AddExternalInputs([input_blob, weights])
        param_extract_mapping = {}

        # Populate the weights-blob from blobs containing parameters for
        # the individual components of the LSTM, such as forget/input gate
        # weights and biases. Also, create a special param_extract_net that
        # can be used to grab those individual params from the black-box
        # weights blob. These results can be then fed to InitFromLSTMParams()
        for input_type in ['input', 'recurrent']:
            param_extract_mapping[input_type] = {}
            p = recurrent_params if input_type == 'recurrent' else input_params
            if p is None:
                p = {}
            for pname in weight_params + bias_params:
                for j in range(0, num_layers):
                    values = p[pname] if pname in p else init(j, pname, input_type)
                    model.param_init_net.RecurrentParamSet(
                        [input_blob, weights, values],
                        weights,
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    if pname not in param_extract_mapping[input_type]:
                        param_extract_mapping[input_type][pname] = {}
                    b = param_extract_net.RecurrentParamGet(
                        [input_blob, weights],
                        ["lstm_{}_{}_{}".format(input_type, pname, j)],
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    param_extract_mapping[input_type][pname][j] = b

        (hidden_input_blob, cell_input_blob) = initial_states
        output, hidden_output, cell_output, rnn_scratch, dropout_states = \
            model.net.Recurrent(
                [input_blob, hidden_input_blob, cell_input_blob, weights],
                ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
                 "lstm_rnn_scratch", "lstm_dropout_states"],
                seed=random.randint(0, 100000),  # TODO: dropout seed
                **lstm_args
            )
        model.net.AddExternalOutputs(
            hidden_output, cell_output, rnn_scratch, dropout_states)

    if return_params:
        param_extract = param_extract_net, param_extract_mapping
        return output, hidden_output, cell_output, param_extract
    else:
        return output, hidden_output, cell_output
  1608. def LSTMWithAttention(
  1609. model,
  1610. decoder_inputs,
  1611. decoder_input_lengths,
  1612. initial_decoder_hidden_state,
  1613. initial_decoder_cell_state,
  1614. initial_attention_weighted_encoder_context,
  1615. encoder_output_dim,
  1616. encoder_outputs,
  1617. encoder_lengths,
  1618. decoder_input_dim,
  1619. decoder_state_dim,
  1620. scope,
  1621. attention_type=AttentionType.Regular,
  1622. outputs_with_grads=(0, 4),
  1623. weighted_encoder_outputs=None,
  1624. lstm_memory_optimization=False,
  1625. attention_memory_optimization=False,
  1626. forget_bias=0.0,
  1627. forward_only=False,
  1628. ):
  1629. '''
  1630. Adds a LSTM with attention mechanism to a model.
  1631. The implementation is based on https://arxiv.org/abs/1409.0473, with
  1632. a small difference in the order
  1633. how we compute new attention context and new hidden state, similarly to
  1634. https://arxiv.org/abs/1508.04025.
  1635. The model uses encoder-decoder naming conventions,
  1636. where the decoder is the sequence the op is iterating over,
  1637. while computing the attention context over the encoder.
  1638. model: ModelHelper object new operators would be added to
  1639. decoder_inputs: the input sequence in a format T x N x D
  1640. where T is sequence size, N - batch size and D - input dimension
  1641. decoder_input_lengths: blob containing sequence lengths
  1642. which would be passed to LSTMUnit operator
  1643. initial_decoder_hidden_state: initial hidden state of LSTM
  1644. initial_decoder_cell_state: initial cell state of LSTM
  1645. initial_attention_weighted_encoder_context: initial attention context
  1646. encoder_output_dim: dimension of encoder outputs
  1647. encoder_outputs: the sequence, on which we compute the attention context
  1648. at every iteration
  1649. encoder_lengths: a tensor with lengths of each encoder sequence in batch
  1650. (may be None, meaning all encoder sequences are of same length)
  1651. decoder_input_dim: input dimension (last dimension on decoder_inputs)
  1652. decoder_state_dim: size of hidden states of LSTM
  1653. attention_type: One of: AttentionType.Regular, AttentionType.Recurrent.
  1654. Determines which type of attention mechanism to use.
  1655. outputs_with_grads : position indices of output blobs which will receive
  1656. external error gradient during backpropagation
  1657. weighted_encoder_outputs: encoder outputs to be used to compute attention
  1658. weights. In the basic case it's just linear transformation of
  1659. encoder outputs (that the default, when weighted_encoder_outputs is None).
  1660. However, it can be something more complicated - like a separate
  1661. encoder network (for example, in case of convolutional encoder)
  1662. lstm_memory_optimization: recompute LSTM activations on backward pass, so
  1663. we don't need to store their values in forward passes
  1664. attention_memory_optimization: recompute attention for backward pass
  1665. forward_only: whether to create only forward pass
  1666. '''
  1667. cell = LSTMWithAttentionCell(
  1668. encoder_output_dim=encoder_output_dim,
  1669. encoder_outputs=encoder_outputs,
  1670. encoder_lengths=encoder_lengths,
  1671. decoder_input_dim=decoder_input_dim,
  1672. decoder_state_dim=decoder_state_dim,
  1673. name=scope,
  1674. attention_type=attention_type,
  1675. weighted_encoder_outputs=weighted_encoder_outputs,
  1676. forget_bias=forget_bias,
  1677. lstm_memory_optimization=lstm_memory_optimization,
  1678. attention_memory_optimization=attention_memory_optimization,
  1679. forward_only=forward_only,
  1680. )
  1681. initial_states = [
  1682. initial_decoder_hidden_state,
  1683. initial_decoder_cell_state,
  1684. initial_attention_weighted_encoder_context,
  1685. ]
  1686. if attention_type == AttentionType.SoftCoverage:
  1687. initial_states.append(cell.build_initial_coverage(model))
  1688. _, result = cell.apply_over_sequence(
  1689. model=model,
  1690. inputs=decoder_inputs,
  1691. seq_lengths=decoder_input_lengths,
  1692. initial_states=initial_states,
  1693. outputs_with_grads=outputs_with_grads,
  1694. )
  1695. return result
  1696. def _layered_LSTM(
  1697. model, input_blob, seq_lengths, initial_states,
  1698. dim_in, dim_out, scope, outputs_with_grads=(0,), return_params=False,
  1699. memory_optimization=False, forget_bias=0.0, forward_only=False,
  1700. drop_states=False, create_lstm=None):
  1701. params = locals() # leave it as a first line to grab all params
  1702. params.pop('create_lstm')
  1703. if not isinstance(dim_out, list):
  1704. return create_lstm(**params)
  1705. elif len(dim_out) == 1:
  1706. params['dim_out'] = dim_out[0]
  1707. return create_lstm(**params)
  1708. assert len(dim_out) != 0, "dim_out list can't be empty"
  1709. assert return_params is False, "return_params not supported for layering"
  1710. for i, output_dim in enumerate(dim_out):
  1711. params.update({
  1712. 'dim_out': output_dim
  1713. })
  1714. output, last_output, all_states, last_state = create_lstm(**params)
  1715. params.update({
  1716. 'input_blob': output,
  1717. 'dim_in': output_dim,
  1718. 'initial_states': (last_output, last_state),
  1719. 'scope': scope + '_layer_{}'.format(i + 1)
  1720. })
  1721. return output, last_output, all_states, last_state
# Public entry point: _layered_LSTM with LSTM (defined elsewhere in this
# module) bound as the per-layer factory. Pass a single size as dim_out for
# one LSTM, or a list of sizes to build a stack of LSTMs.
layered_LSTM = functools.partial(_layered_LSTM, create_lstm=LSTM)