| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981 |
- ## @package rnn_cell
- # Module caffe2.python.rnn_cell
- import functools
- import inspect
- import logging
- import numpy as np
- import random
- from future.utils import viewkeys
- from caffe2.proto import caffe2_pb2
- from caffe2.python.attention import (
- apply_dot_attention,
- apply_recurrent_attention,
- apply_regular_attention,
- apply_soft_coverage_attention,
- AttentionType,
- )
- from caffe2.python import core, recurrent, workspace, brew, scope, utils
- from caffe2.python.modeling.parameter_sharing import ParameterSharing
- from caffe2.python.modeling.parameter_info import ParameterTags
- from caffe2.python.modeling.initializers import Initializer
- from caffe2.python.model_helper import ModelHelper
- def _RectifyName(blob_reference_or_name):
- if blob_reference_or_name is None:
- return None
- if isinstance(blob_reference_or_name, str):
- return core.ScopedBlobReference(blob_reference_or_name)
- if not isinstance(blob_reference_or_name, core.BlobReference):
- raise Exception("Unknown blob reference type")
- return blob_reference_or_name
- def _RectifyNames(blob_references_or_names):
- if blob_references_or_names is None:
- return None
- return [_RectifyName(i) for i in blob_references_or_names]
class RNNCell(object):
    '''
    Base class for writing recurrent / stateful operations.

    One needs to implement 2 methods: apply_override
    and get_state_names_override.

    As a result base class will provide apply_over_sequence method, which
    allows you to apply recurrent operations over a sequence of any length.

    As optional you could add input and output preparation steps by overriding
    corresponding methods.
    '''

    def __init__(self, name=None, forward_only=False, initializer=None):
        '''
        name: optional string used as the scope prefix for the cell's blobs
        forward_only: forwarded to recurrent.recurrent_net by
            apply_over_sequence
        initializer: optional object with a create_states(model) method,
            used when apply_over_sequence gets no initial_states
        '''
        self.name = name
        # Blobs handed to recurrent.recurrent_net as
        # recompute_blobs_on_backward; subclasses may fill this in.
        self.recompute_blobs = []
        self.forward_only = forward_only
        self._initializer = initializer

    @property
    def initializer(self):
        return self._initializer

    @initializer.setter
    def initializer(self, value):
        self._initializer = value

    def scope(self, name):
        '''
        Prefix name with "<self.name>/" when the cell has a name; otherwise
        return name unchanged.
        '''
        return self.name + '/' + name if self.name is not None else name

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths=None,
        initial_states=None,
        outputs_with_grads=None,
    ):
        '''
        Unroll the cell over a whole input sequence via
        recurrent.recurrent_net.

        model: ModelHelper object new operators would be added to
        inputs: the input sequence blob (passed through prepare_input)
        seq_lengths: optional blob with per-example sequence lengths
        initial_states: optional list of initial state blobs; when None
            they are created by self.initializer
        outputs_with_grads: indices into the returned state list which
            receive gradients; defaults to the primary output sequence

        Returns a tuple (output, states_for_all_steps).
        Raises Exception when neither initial_states nor an initializer
        is available.
        '''
        if initial_states is None:
            with scope.NameScope(self.name):
                if self.initializer is None:
                    raise Exception("Either initial states "
                                    "or initializer have to be set")
                initial_states = self.initializer.create_states(model)

        preprocessed_inputs = self.prepare_input(model, inputs)
        step_model = ModelHelper(name=self.name, param_model=model)
        input_t, timestep = step_model.net.AddScopedExternalInputs(
            'input_t',
            'timestep',
        )
        utils.raiseIfNotEqual(
            len(initial_states), len(self.get_state_names()),
            "Number of initial state values provided doesn't match the number "
            "of states"
        )
        states_prev = step_model.net.AddScopedExternalInputs(*[
            s + '_prev' for s in self.get_state_names()
        ])
        states = self._apply(
            model=step_model,
            input_t=input_t,
            seq_lengths=seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        # Make sure every produced state is an external output of the step
        # net, without registering any of them twice.
        external_outputs = set(step_model.net.Proto().external_output)
        for state in states:
            if state not in external_outputs:
                step_model.net.AddExternalOutput(state)

        if outputs_with_grads is None:
            outputs_with_grads = [self.get_output_state_index() * 2]

        # states_for_all_steps consists of combination of
        # states gather for all steps and final states. It looks like this:
        # (state_1_all, state_1_final, state_2_all, state_2_final, ...)
        states_for_all_steps = recurrent.recurrent_net(
            net=model.net,
            cell_net=step_model.net,
            inputs=[(input_t, preprocessed_inputs)],
            initial_cell_inputs=list(zip(states_prev, initial_states)),
            links=dict(zip(states_prev, states)),
            timestep=timestep,
            scope=self.name,
            forward_only=self.forward_only,
            outputs_with_grads=outputs_with_grads,
            recompute_blobs_on_backward=self.recompute_blobs,
        )

        output = self._prepare_output_sequence(
            model,
            states_for_all_steps,
        )
        return output, states_for_all_steps

    def apply(self, model, input_t, seq_lengths, states, timestep):
        '''
        Apply a single step of the cell: prepare the input, run _apply and
        post-process the primary output.

        Returns a tuple (output, states).
        '''
        input_t = self.prepare_input(model, input_t)
        states = self._apply(
            model, input_t, seq_lengths, states, timestep)
        output = self._prepare_output(model, states)
        return output, states

    def _apply(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None
    ):
        '''
        This method uses apply_override provided by a custom cell.
        On the top it takes care of applying self.scope() to all the outputs.
        While all the inputs stay within the scope this function was called
        from.
        '''
        args = self._rectify_apply_inputs(
            input_t, seq_lengths, states, timestep, extra_inputs)
        with core.NameScope(self.name):
            return self.apply_override(model, *args)

    def _rectify_apply_inputs(
            self, input_t, seq_lengths, states, timestep, extra_inputs):
        '''
        Before applying a scope we make sure that all external blob names
        are converted to blob reference. So further scoping doesn't affect them

        Returns the list [input_t, seq_lengths, states, timestep], with
        extra_inputs appended when apply_override declares such a parameter.
        '''
        input_t, seq_lengths, timestep = _RectifyNames(
            [input_t, seq_lengths, timestep])
        states = _RectifyNames(states)
        if extra_inputs:
            extra_input_names, extra_input_sizes = zip(*extra_inputs)
            # Pair the *rectified* names with their sizes, and materialize
            # the result so it can be iterated more than once.
            extra_inputs = list(zip(
                _RectifyNames(extra_input_names), extra_input_sizes))

        # inspect.getargspec is gone from recent Python 3 releases; prefer
        # getfullargspec and only fall back where it does not exist.
        get_spec = getattr(inspect, 'getfullargspec', None) \
            or inspect.getargspec
        arg_names = get_spec(self.apply_override).args
        rectified = [input_t, seq_lengths, states, timestep]
        if 'extra_inputs' in arg_names:
            rectified.append(extra_inputs)
        return rectified

    def apply_override(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None,
    ):
        '''
        A single step of a recurrent network to be implemented by each custom
        RNNCell.

        model: ModelHelper object new operators would be added to

        input_t: single input with shape (1, batch_size, input_dim)

        seq_lengths: blob containing sequence lengths which would be passed to
        LSTMUnit operator

        states: previous recurrent states

        timestep: current recurrent iteration. Could be used together with
        seq_lengths in order to determine, if some shorter sequences
        in the batch have already ended.

        extra_inputs: list of tuples (input, dim). specifies additional input
        which is not subject to prepare_input(). (useful when a cell is a
        component of a larger recurrent structure, e.g., attention)
        '''
        raise NotImplementedError('Abstract method')

    def prepare_input(self, model, input_blob):
        '''
        If some operations in _apply method depend only on the input,
        not on recurrent states, they could be computed in advance.

        model: ModelHelper object new operators would be added to

        input_blob: either the whole input sequence with shape
        (sequence_length, batch_size, input_dim) or a single input with shape
        (1, batch_size, input_dim).
        '''
        return input_blob

    def get_output_state_index(self):
        '''
        Return index into state list of the "primary" step-wise output.
        '''
        return 0

    def get_state_names(self):
        '''
        Returns recurrent state names with self.name scoping applied
        '''
        return [self.scope(name) for name in self.get_state_names_override()]

    def get_state_names_override(self):
        '''
        Override this function in your custom cell.
        It should return the names of the recurrent states.

        It's required by apply_over_sequence method in order to allocate
        recurrent states for all steps with meaningful names.
        '''
        raise NotImplementedError('Abstract method')

    def get_output_dim(self):
        '''
        Specifies the dimension (number of units) of stepwise output.
        '''
        raise NotImplementedError('Abstract method')

    def _prepare_output(self, model, states):
        '''
        Allows arbitrary post-processing of primary output.
        '''
        return states[self.get_output_state_index()]

    def _prepare_output_sequence(self, model, state_outputs):
        '''
        Allows arbitrary post-processing of primary sequence output.

        (Note that state_outputs alternates between full-sequence and final
        output for each state, thus the index multiplier 2.)
        '''
        output_sequence_index = 2 * self.get_output_state_index()
        return state_outputs[output_sequence_index]
class LSTMInitializer(object):
    '''
    Creates the initial hidden and cell state parameters of an LSTM, both
    filled with the constant 0.0 and of shape [hidden_size].
    '''

    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

    def create_states(self, model):
        '''
        Create the two state parameters on model and return them as
        [initial_hidden_state, initial_cell_state].
        '''
        def zero_state(param_name):
            return model.create_param(
                param_name=param_name,
                initializer=Initializer(
                    operator_name='ConstantFill', value=0.0),
                shape=[self.hidden_size],
            )

        hidden_state = zero_state('initial_hidden_state')
        cell_state = zero_state('initial_cell_state')
        return [hidden_state, cell_state]
- # based on https://pytorch.org/docs/master/nn.html#torch.nn.RNNCell
class BasicRNNCell(RNNCell):
    '''
    Plain recurrent cell: hidden_t = act(FC(hidden_t_prev) + input_t), where
    act is either 'tanh' or 'relu' and input_t is the output of
    prepare_input (an FC over the raw input).

    forget_bias and memory_optimization are accepted but not used by this
    cell.
    '''

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        activation=None,
        **kwargs
    ):
        '''
        Raises RuntimeError when activation is not 'relu' or 'tanh'.
        '''
        # Forward the initializer to the base class; otherwise a
        # caller-supplied initializer would be silently discarded.
        super(BasicRNNCell, self).__init__(initializer=initializer, **kwargs)
        self.drop_states = drop_states
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation

        if self.activation not in ['relu', 'tanh']:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

    def apply_override(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One recurrent step. Returns a 1-tuple holding the new hidden state.

        When seq_lengths is given, batch entries with
        seq_lengths <= timestep are masked: their new hidden state is
        zeroed if drop_states is set, otherwise the previous hidden state
        is carried over.
        '''
        hidden_t_prev = states[0]

        gates_t = brew.fc(
            model,
            hidden_t_prev,
            'gates_t',
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        # Accumulate the (already projected) input into gates_t in place.
        brew.sum(model, [gates_t, input_t], gates_t)
        if self.activation == 'tanh':
            hidden_t = model.net.Tanh(gates_t, 'hidden_t')
        elif self.activation == 'relu':
            hidden_t = model.net.Relu(gates_t, 'hidden_t')
        else:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

        if seq_lengths is not None:
            # TODO If this codepath becomes popular, it may be worth
            # taking a look at optimizing it - for now a simple
            # implementation is used to round out compatibility with
            # ONNX.
            timestep = model.net.CopyFromCPUInput(
                timestep, 'timestep_gpu')
            valid_b = model.net.GT(
                [seq_lengths, timestep], 'valid_b', broadcast=1)
            invalid_b = model.net.LE(
                [seq_lengths, timestep], 'invalid_b', broadcast=1)
            valid = model.net.Cast(valid_b, 'valid', to='float')
            invalid = model.net.Cast(invalid_b, 'invalid', to='float')

            hidden_valid = model.net.Mul(
                [hidden_t, valid],
                'hidden_valid',
                broadcast=1,
                axis=1,
            )
            if self.drop_states:
                hidden_t = hidden_valid
            else:
                hidden_invalid = model.net.Mul(
                    [hidden_t_prev, invalid],
                    'hidden_invalid',
                    broadcast=1, axis=1)
                hidden_t = model.net.Add(
                    [hidden_valid, hidden_invalid], hidden_t)
        return (hidden_t,)

    def prepare_input(self, model, input_blob):
        '''
        Project the input from input_size to hidden_size with an FC layer
        named '<scope>i2h'.
        '''
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'),)

    def get_output_dim(self):
        return self.hidden_size
class LSTMCell(RNNCell):
    '''
    LSTM cell built from an FC over the previous hidden state (optionally
    concatenated with extra inputs), summed with the prepared input and fed
    to the LSTMUnit operator.
    '''

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        **kwargs
    ):
        super(LSTMCell, self).__init__(initializer=initializer, **kwargs)
        # Fall back to zero-filled initial states when none are supplied.
        if not initializer:
            initializer = LSTMInitializer(hidden_size=hidden_size)
        self.initializer = initializer

        self.input_size = input_size
        self.hidden_size = hidden_size
        # One slice of the gates blob per LSTM gate.
        self.gates_size = 4 * self.hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states

    def apply_override(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One LSTM step. Returns the tuple (hidden_t, cell_t).

        seq_lengths may be None, in which case it is left out of the
        LSTMUnit inputs and the operator gets sequence_lengths=False.
        '''
        hidden_t_prev, cell_t_prev = states

        recurrent_input = hidden_t_prev
        recurrent_dim = self.hidden_size

        if extra_inputs is not None:
            extra_blobs, extra_sizes = zip(*extra_inputs)
            recurrent_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_blobs),
                'gates_concatenated_input_t',
                axis=2,
            )
            recurrent_dim += sum(extra_sizes)

        gates_t = brew.fc(
            model,
            recurrent_input,
            'gates_t',
            dim_in=recurrent_dim,
            dim_out=self.gates_size,
            axis=2,
        )
        # In-place accumulation of the prepared input into the gates.
        brew.sum(model, [gates_t, input_t], gates_t)

        has_lengths = seq_lengths is not None
        unit_inputs = [hidden_t_prev, cell_t_prev, gates_t]
        if has_lengths:
            unit_inputs.append(seq_lengths)
        unit_inputs.append(timestep)

        hidden_t, cell_t = model.net.LSTMUnit(
            unit_inputs,
            ['hidden_state', 'cell_state'],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=has_lengths,
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]

        return hidden_t, cell_t

    def get_input_params(self):
        '''Names of the weight and bias blobs of the input FC ('i2h').'''
        prefix = self.scope('i2h')
        return {
            'weights': prefix + '_w',
            'biases': prefix + '_b',
        }

    def get_recurrent_params(self):
        '''Names of the weight and bias blobs of the recurrent FC.'''
        prefix = self.scope('gates_t')
        return {
            'weights': prefix + '_w',
            'biases': prefix + '_b',
        }

    def prepare_input(self, model, input_blob):
        '''
        Project the input from input_size to gates_size with an FC layer
        named '<scope>i2h'.
        '''
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.gates_size,
            axis=2,
        )

    def get_state_names_override(self):
        return ['hidden_t', 'cell_t']

    def get_output_dim(self):
        return self.hidden_size
class LayerNormLSTMCell(RNNCell):
    '''
    LSTM cell which applies brew.layer_norm to the gates before they are
    passed to the LSTMUnit operator.
    '''

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        **kwargs
    ):
        super(LayerNormLSTMCell, self).__init__(
            initializer=initializer, **kwargs
        )
        # Fall back to zero-filled initial states when none are supplied.
        self.initializer = initializer or LSTMInitializer(
            hidden_size=hidden_size
        )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.gates_size = 4 * self.hidden_size

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One layer-normalized LSTM step. Returns the tuple (hidden_t, cell_t).

        seq_lengths may be None; it is then left out of the LSTMUnit inputs
        and the operator gets sequence_lengths=False, matching LSTMCell.
        '''
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        gates_t = brew.fc(
            model,
            fc_input,
            self.scope('gates_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )
        brew.sum(model, [gates_t, input_t], gates_t)

        # brew.layer_norm call is only difference from LSTMCell
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=self.gates_size,
            axis=-1,
        )

        # A None entry is not a valid operator input, so only include
        # seq_lengths when the caller actually provided it.
        has_seq_lengths = seq_lengths is not None
        if has_seq_lengths:
            unit_inputs = [
                hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
        else:
            unit_inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]

        hidden_t, cell_t = model.net.LSTMUnit(
            unit_inputs,
            self.get_state_names(),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=has_seq_lengths,
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t

    def get_input_params(self):
        '''Names of the weight and bias blobs of the input FC ('i2h').'''
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def prepare_input(self, model, input_blob):
        '''
        Project the input from input_size to gates_size with an FC layer
        named '<scope>i2h'.
        '''
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.gates_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'), self.scope('cell_t'))

    def get_output_dim(self):
        # Same stepwise output width as the other LSTM cells in this module.
        return self.hidden_size
class MILSTMCell(LSTMCell):
    '''
    LSTM cell whose gates combine the prepared input and the recurrent
    projection multiplicatively:

        gates = alpha * input_t * prev_t + beta_h * prev_t
                + beta_i * input_t + b
    '''

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One MI-LSTM step. Returns the tuple (hidden_t, cell_t).

        seq_lengths may be None; it is then left out of the LSTMUnit inputs
        and the operator gets sequence_lengths=False, matching LSTMCell.
        '''
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )

        # defining initializers for MI parameters
        alpha = model.create_param(
            self.scope('alpha'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        # alpha * input_t + beta_h
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        # (alpha * input_t + beta_h) * prev_t =
        # alpha * input_t * prev_t + beta_h * prev_t
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        # beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )

        # A None entry is not a valid operator input, so only include
        # seq_lengths when the caller actually provided it.
        has_seq_lengths = seq_lengths is not None
        if has_seq_lengths:
            unit_inputs = [
                hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
        else:
            unit_inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]

        hidden_t, cell_t = model.net.LSTMUnit(
            unit_inputs,
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=has_seq_lengths,
        )
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t
class LayerNormMILSTMCell(LSTMCell):
    '''
    Multiplicative-integration LSTM cell which additionally applies
    brew.layer_norm to the gates before the LSTMUnit operator:

        gates = layer_norm(alpha * input_t * prev_t + beta_h * prev_t
                           + beta_i * input_t + b)
    '''

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One layer-normalized MI-LSTM step. Returns the tuple
        (hidden_t, cell_t).

        seq_lengths may be None; it is then left out of the LSTMUnit inputs
        and the operator gets sequence_lengths=False, matching LSTMCell.
        '''
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )

        # defining initializers for MI parameters
        alpha = model.create_param(
            self.scope('alpha'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        # alpha * input_t + beta_h
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        # (alpha * input_t + beta_h) * prev_t =
        # alpha * input_t * prev_t + beta_h * prev_t
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        # beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )
        # brew.layer_norm call is only difference from MILSTMCell._apply
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=self.gates_size,
            axis=-1,
        )

        # A None entry is not a valid operator input, so only include
        # seq_lengths when the caller actually provided it.
        has_seq_lengths = seq_lengths is not None
        if has_seq_lengths:
            unit_inputs = [
                hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
        else:
            unit_inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]

        hidden_t, cell_t = model.net.LSTMUnit(
            unit_inputs,
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=has_seq_lengths,
        )
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t
class DropoutCell(RNNCell):
    '''
    Wraps arbitrary RNNCell, applying dropout to its output (but not to the
    recurrent connection for the corresponding state).
    '''

    def __init__(
        self,
        internal_cell,
        dropout_ratio=None,
        use_cudnn=False,
        **kwargs
    ):
        '''
        internal_cell: the wrapped RNNCell
        dropout_ratio: dropout ratio; no dropout op is added when it is
            None or otherwise falsy
        use_cudnn: forwarded to brew.dropout
        is_test: required keyword argument, forwarded to brew.dropout
        '''
        self.internal_cell = internal_cell
        self.dropout_ratio = dropout_ratio
        assert 'is_test' in kwargs, "Argument 'is_test' is required"
        self.is_test = kwargs.pop('is_test')
        self.use_cudnn = use_cudnn
        super(DropoutCell, self).__init__(**kwargs)

        # Delegate these straight to the wrapped cell's bound methods.
        delegated = (
            'prepare_input',
            'get_output_state_index',
            'get_state_names',
            'get_output_dim',
        )
        for method_name in delegated:
            setattr(self, method_name, getattr(internal_cell, method_name))

        # Counter appended to dropout blob names so each one is distinct.
        self.mask = 0

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''Run the wrapped cell's step unchanged.'''
        inner = self.internal_cell
        return inner._apply(
            model, input_t, seq_lengths, states, timestep, extra_inputs)

    def _prepare_output(self, model, states):
        '''Wrapped cell's step output, with dropout applied if configured.'''
        inner_output = self.internal_cell._prepare_output(model, states)
        if self.dropout_ratio is None:
            return inner_output
        return self._apply_dropout(model, inner_output)

    def _prepare_output_sequence(self, model, state_outputs):
        '''Wrapped cell's sequence output, with dropout if configured.'''
        inner_output = self.internal_cell._prepare_output_sequence(
            model, state_outputs)
        if self.dropout_ratio is None:
            return inner_output
        return self._apply_dropout(model, inner_output)

    def _apply_dropout(self, model, output):
        '''
        Add a dropout op on output, unless the ratio is falsy or the cell
        is forward-only, in which case output is returned untouched.
        '''
        if not self.dropout_ratio or self.forward_only:
            return output
        with core.NameScope(self.name or ''):
            suffix = '_with_dropout_mask{}'.format(self.mask)
            dropped = brew.dropout(
                model,
                output,
                str(output) + suffix,
                ratio=float(self.dropout_ratio),
                is_test=self.is_test,
                use_cudnn=self.use_cudnn,
            )
            self.mask += 1
        return dropped
class MultiRNNCellInitializer(object):
    '''
    Builds the initial states of a stack of cells by asking each cell's own
    initializer, inside a "layer_<i>" / <cell.name> name scope.
    '''

    def __init__(self, cells):
        self.cells = cells

    def create_states(self, model):
        '''
        Return the concatenated list of initial states of all cells, in
        layer order. Raises Exception when a cell has no initializer.
        '''
        all_states = []
        for layer_id, cell in enumerate(self.cells):
            if cell.initializer is None:
                raise Exception("Either initial states "
                                "or initializer have to be set")

            with core.NameScope("layer_{}".format(layer_id)):
                with core.NameScope(cell.name):
                    layer_states = cell.initializer.create_states(model)
                    all_states.extend(layer_states)
        return all_states
- class MultiRNNCell(RNNCell):
- '''
- Multilayer RNN via the composition of RNNCell instances.
- It is the responsibility of calling code to ensure the compatibility
- of the successive layers in terms of input/output dimensionality, etc.,
- and to ensure that their blobs do not have name conflicts, typically by
- creating the cells with names that specify layer number.
- Assumes first state (recurrent output) for each layer should be the input
- to the next layer.
- '''
def __init__(self, cells, residual_output_layers=None, **kwargs):
    '''
    cells: list of RNNCell instances, from input to output side.

    name: string designating network component (for scoping)

    residual_output_layers: list of indices of layers whose input will
    be added elementwise to their output. (It is the
    responsibility of the client code to ensure shape compatibility.)
    Note that layer 0 (zero) cannot have residual output because of the
    timing of prepare_input().

    forward_only: used to construct inference-only network.
    '''
    super(MultiRNNCell, self).__init__(**kwargs)
    self.cells = cells

    self.residual_output_layers = (
        [] if residual_output_layers is None else residual_output_layers
    )

    # Position of every layer's primary output within the flat list of
    # states of all layers.
    output_index_per_layer = []
    offset = 0
    for cell in self.cells:
        output_index_per_layer.append(
            offset + cell.get_output_state_index())
        offset += len(cell.get_state_names())

    self.output_connected_layers = []
    self.output_indices = []
    last_layer = len(self.cells) - 1
    for layer_id in range(last_layer):
        if (layer_id + 1) not in self.residual_output_layers:
            # The next layer is not residual: start collecting afresh.
            self.output_connected_layers = []
            self.output_indices = []
            continue
        self.output_connected_layers.append(layer_id)
        self.output_indices.append(output_index_per_layer[layer_id])

    self.output_connected_layers.append(last_layer)
    self.output_indices.append(output_index_per_layer[-1])

    self.state_names = []
    for layer_id, cell in enumerate(self.cells):
        scoper = self.layer_scoper(layer_id)
        self.state_names.extend(
            scoper(state_name) for state_name in cell.get_state_names())

    self.initializer = MultiRNNCellInitializer(cells)
def layer_scoper(self, layer_id):
    '''
    Return a function mapping a blob name to
    "<self.name>/layer_<layer_id>/<name>".
    '''
    prefix = "{}/layer_{}".format(self.name, layer_id)

    def helper(name):
        return "{}/{}".format(prefix, name)

    return helper
def prepare_input(self, model, input_blob):
    '''
    Prepare the input with the first layer's prepare_input, run inside
    this cell's name scope. The input is rectified to a blob reference
    first so that the scope does not affect its name.
    '''
    rectified_input = _RectifyName(input_blob)
    first_cell = self.cells[0]
    with core.NameScope(self.name or ''):
        return first_cell.prepare_input(model, rectified_input)
def _apply(
    self,
    model,
    input_t,
    seq_lengths,
    states,
    timestep,
    extra_inputs=None,
):
    '''
    Apply one step of every layer, from the input side to the output side.

    states: flat list with the states of all layers, in layer order.
    extra_inputs: handed to the first layer only; upper layers get None.

    Returns the flat list of next states of all layers, in layer order.

    Because below we will do scoping across layers, we need
    to make sure that string blob names are converted to BlobReference
    objects.
    '''
    input_t, seq_lengths, states, timestep, extra_inputs = (
        self._rectify_apply_inputs(
            input_t, seq_lengths, states, timestep, extra_inputs)
    )

    state_counts = [len(cell.get_state_names()) for cell in self.cells]
    assert len(states) == sum(state_counts)

    top_layer = len(self.cells) - 1
    next_states = []
    offset = 0
    layer_input = input_t
    for layer_id, layer_cell in enumerate(self.cells):
        # Even if cells don't have different names we still take care
        # of scoping.
        with core.NameScope(self.name), \
                core.NameScope("layer_{}".format(layer_id)):
            count = state_counts[layer_id]
            layer_states = states[offset:offset + count]
            offset += count

            if layer_id == 0:
                # The first layer's input was already prepared by
                # prepare_input(); it is also the only layer that
                # receives the extra inputs.
                prepared_input = layer_input
                layer_extra_inputs = extra_inputs
            else:
                prepared_input = layer_cell.prepare_input(
                    model, layer_input)
                layer_extra_inputs = None

            layer_next_states = layer_cell._apply(
                model,
                prepared_input,
                seq_lengths,
                layer_states,
                timestep,
                extra_inputs=layer_extra_inputs,
            )

            # Since we're using here non-public method _apply,
            # instead of apply, we have to manually extract output
            # from states
            if layer_id != top_layer:
                layer_output = layer_cell._prepare_output(
                    model,
                    layer_next_states,
                )
                if layer_id > 0 and layer_id in self.residual_output_layers:
                    # Residual connection: add this layer's input to
                    # its output before feeding the next layer.
                    layer_input = brew.sum(
                        model,
                        [layer_output, layer_input],
                        self.scope('residual_output_{}'.format(layer_id)),
                    )
                else:
                    layer_input = layer_output

            next_states.extend(layer_next_states)
    return next_states
def get_state_names(self):
    '''
    Return self.state_names: the list assembled by the constructor from
    the state names of every layer, each prefixed with its layer scope.
    The stored list itself is returned, not a copy.
    '''
    return self.state_names
def get_output_state_index(self):
    '''
    Return the position of the last layer's output state within the flat
    state list: the number of states of all lower layers plus the last
    layer's own output state index.
    '''
    states_below = sum(
        len(lower.get_state_names()) for lower in self.cells[:-1]
    )
    return states_below + self.cells[-1].get_output_state_index()
def _prepare_output(self, model, states):
    '''
    Build the output blob of the stack from the flat list of states.

    Every layer listed in self.output_connected_layers contributes the
    output its own cell derives from its slice of `states`. A single
    contribution is returned as is; several are summed elementwise into
    the scoped blob 'residual_output'.
    '''
    contributions = []
    offset = 0
    for layer_id, layer_cell in enumerate(self.cells):
        count = len(layer_cell.get_state_names())
        if layer_id in self.output_connected_layers:
            contributions.append(
                layer_cell._prepare_output(
                    model,
                    states[offset:offset + count],
                )
            )
        offset += count

    if len(contributions) > 1:
        return brew.sum(
            model,
            contributions,
            self.scope('residual_output'),
        )
    return contributions[0]
def _prepare_output_sequence(self, model, states):
    '''
    Build the output sequence blob of the stack.

    `states` is sliced with two entries per state name, so each layer
    owns 2 * len(cell.get_state_names()) consecutive entries. Every layer
    in self.output_connected_layers contributes the sequence its own cell
    derives from its slice. A single contribution is returned as is;
    several are summed elementwise into the scoped blob
    'residual_output_sequence'.
    '''
    contributions = []
    offset = 0
    for layer_id, layer_cell in enumerate(self.cells):
        count = 2 * len(layer_cell.get_state_names())
        if layer_id in self.output_connected_layers:
            contributions.append(
                layer_cell._prepare_output_sequence(
                    model,
                    states[offset:offset + count],
                )
            )
        offset += count

    if len(contributions) > 1:
        return brew.sum(
            model,
            contributions,
            self.scope('residual_output_sequence'),
        )
    return contributions[0]
class AttentionCell(RNNCell):
    '''
    Decoder cell with attention over a sequence of encoder outputs.

    Wraps `decoder_cell`. The cell's states are the decoder cell's states
    followed by the attention-weighted encoder context and, for
    AttentionType.SoftCoverage only, a coverage state.
    '''

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_cell,
        decoder_state_dim,
        attention_type,
        weighted_encoder_outputs,
        attention_memory_optimization,
        **kwargs
    ):
        '''
        encoder_output_dim: dimension of the encoder outputs.
        encoder_outputs: blob with the encoder outputs to attend over.
        encoder_lengths: forwarded to the attention helper on every step.
        decoder_cell: the wrapped cell producing the decoder states.
        decoder_state_dim: dimension of the decoder hidden state.
        attention_type: one of AttentionType.Regular, Recurrent, Dot or
            SoftCoverage.
        weighted_encoder_outputs: blob used to compute attention weights;
            when None (and the type is not Dot) it is created in
            prepare_input().
        attention_memory_optimization: when truthy, the blobs reported by
            the attention helper are added to self.recompute_blobs.
        kwargs: forwarded to the base-class constructor.
        '''
        super(AttentionCell, self).__init__(**kwargs)
        self.encoder_output_dim = encoder_output_dim
        self.encoder_outputs = encoder_outputs
        self.encoder_lengths = encoder_lengths
        self.decoder_cell = decoder_cell
        self.decoder_state_dim = decoder_state_dim
        self.weighted_encoder_outputs = weighted_encoder_outputs
        # Created lazily by prepare_input().
        self.encoder_outputs_transposed = None

        supported_types = [
            AttentionType.Regular,
            AttentionType.Recurrent,
            AttentionType.Dot,
            AttentionType.SoftCoverage,
        ]
        assert attention_type in supported_types
        self.attention_type = attention_type
        self.attention_memory_optimization = attention_memory_optimization

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        Run one decoder step followed by one attention step.

        states: decoder states, then the previous attention context, then
            (SoftCoverage only) the previous coverage.
        extra_inputs: must be None; the attention context is what gets
            passed to the decoder cell as its extra input.

        Returns the next states in the same layout, with the decoder
        output state replaced by a copy named 'hidden_t_external'.
        '''
        with_coverage = self.attention_type == AttentionType.SoftCoverage
        if with_coverage:
            decoder_prev_states = states[:-2]
            attention_weighted_encoder_context_t_prev = states[-2]
            coverage_t_prev = states[-1]
        else:
            decoder_prev_states = states[:-1]
            attention_weighted_encoder_context_t_prev = states[-1]

        assert extra_inputs is None

        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )

        # Kept on the instance because _prepare_output() needs it.
        self.hidden_t_intermediate = self.decoder_cell._prepare_output(
            model,
            decoder_states,
        )

        # Arguments accepted by every attention helper.
        shared_args = dict(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            decoder_hidden_state_t=self.hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            encoder_lengths=self.encoder_lengths,
        )

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
                **shared_args
            )
        elif self.attention_type == AttentionType.Regular:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                **shared_args
            )
        elif self.attention_type == AttentionType.Dot:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_dot_attention(**shared_args)
        elif self.attention_type == AttentionType.SoftCoverage:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
                coverage_t,
            ) = apply_soft_coverage_attention(
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                coverage_t_prev=coverage_t_prev,
                coverage_weights=self.coverage_weights,
                **shared_args
            )
        else:
            raise Exception('Attention type {} not implemented'.format(
                self.attention_type
            ))

        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        output = list(decoder_states)
        output.append(attention_weighted_encoder_context_t)
        if with_coverage:
            output.append(coverage_t)

        # Expose the decoder output state through a copy with a fixed,
        # scoped name (the one advertised by get_state_names()).
        hidden_index = self.decoder_cell.get_output_state_index()
        output[hidden_index] = model.Copy(
            output[hidden_index],
            self.scope('hidden_t_external'),
        )
        model.net.AddExternalOutputs(*output)

        return output

    def get_attention_weights(self):
        '''
        Return the attention weights blob stored by the latest _apply().
        '''
        # [batch_size, encoder_length, 1]
        return self.attention_weights_3d

    def prepare_input(self, model, input_blob):
        '''
        Create the encoder-side blobs on first use, then let the decoder
        cell prepare the actual input.
        '''
        if self.encoder_outputs_transposed is None:
            self.encoder_outputs_transposed = brew.transpose(
                model,
                self.encoder_outputs,
                self.scope('encoder_outputs_transposed'),
                axes=[1, 2, 0],
            )

        # Dot attention never consumes weighted encoder outputs, so they
        # are only created for the other attention types.
        needs_weighted_outputs = (
            self.attention_type != AttentionType.Dot and
            self.weighted_encoder_outputs is None
        )
        if needs_weighted_outputs:
            self.weighted_encoder_outputs = brew.fc(
                model,
                self.encoder_outputs,
                self.scope('weighted_encoder_outputs'),
                dim_in=self.encoder_output_dim,
                dim_out=self.encoder_output_dim,
                axis=2,
            )

        return self.decoder_cell.prepare_input(model, input_blob)

    def build_initial_coverage(self, model):
        """
        initial_coverage is always zeros of shape [encoder_length],
        which shape must be determined programmatically during network
        computation.

        This method also sets self.coverage_weights, a separate transform
        of encoder_outputs which is used to determine coverage contribution
        to attention.
        """
        assert self.attention_type == AttentionType.SoftCoverage

        # [encoder_length, batch_size, encoder_output_dim]
        self.coverage_weights = brew.fc(
            model,
            self.encoder_outputs,
            self.scope('coverage_weights'),
            dim_in=self.encoder_output_dim,
            dim_out=self.encoder_output_dim,
            axis=2,
        )

        # First entry of the encoder outputs' shape.
        encoder_shape = model.net.Shape(self.encoder_outputs)
        encoder_length = model.net.Slice(
            encoder_shape,
            starts=[0],
            ends=[1],
        )
        device = scope.CurrentDeviceScope()
        if device is not None and core.IsGPUDeviceType(device.device_type):
            # The length is used below as a shape input, so under a GPU
            # device scope it is copied to the CPU first.
            encoder_length = model.net.CopyGPUToCPU(
                encoder_length,
                'encoder_length_cpu',
            )

        # total attention weight applied across decoding steps
        # shape: [encoder_length]
        return model.net.ConstantFill(
            encoder_length,
            self.scope('initial_coverage'),
            value=0.0,
            input_as_shape=1,
        )

    def get_state_names(self):
        '''
        Return the decoder cell's state names with the output state
        renamed to the scoped 'hidden_t_external', followed by the
        attention context name and (SoftCoverage only) the coverage name.
        '''
        names = list(self.decoder_cell.get_state_names())
        names[self.get_output_state_index()] = self.scope(
            'hidden_t_external',
        )
        names.append(self.scope('attention_weighted_encoder_context_t'))
        if self.attention_type == AttentionType.SoftCoverage:
            names.append(self.scope('coverage_t'))
        return names

    def get_output_dim(self):
        '''
        Return decoder_state_dim + encoder_output_dim, i.e. the size of
        the hidden state concatenated with the attention context.
        '''
        return self.decoder_state_dim + self.encoder_output_dim

    def get_output_state_index(self):
        '''
        Return the output state index of the wrapped decoder cell.
        '''
        return self.decoder_cell.get_output_state_index()

    def _prepare_output(self, model, states):
        '''
        Concatenate (axis 2) the hidden state computed by the latest
        _apply() with the attention context taken from `states`.
        '''
        if self.attention_type == AttentionType.SoftCoverage:
            context_position = -2
        else:
            context_position = -1
        attention_context = states[context_position]

        with core.NameScope(self.name or ''):
            return brew.concat(
                model,
                [self.hidden_t_intermediate, attention_context],
                'states_and_context_combination',
                axis=2,
            )

    def _prepare_output_sequence(self, model, state_outputs):
        '''
        Concatenate (axis 2) the decoder cell's output sequence with the
        attention context sequence.

        `state_outputs` is indexed with two entries per state; the
        entries of the attention context (and of the coverage state for
        SoftCoverage) come after those of the decoder cell.
        '''
        if self.attention_type == AttentionType.SoftCoverage:
            trailing_states = 2
        else:
            trailing_states = 1

        decoder_output = self.decoder_cell._prepare_output_sequence(
            model,
            state_outputs[:-2 * trailing_states],
        )
        attention_context_index = 2 * (
            len(self.get_state_names()) - trailing_states
        )

        with core.NameScope(self.name or ''):
            return brew.concat(
                model,
                [
                    decoder_output,
                    state_outputs[attention_context_index],
                ],
                'states_and_context_combination',
                axis=2,
            )
class LSTMWithAttentionCell(AttentionCell):
    '''
    AttentionCell whose decoder is an LSTMCell named "<name>/decoder".
    '''

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
    ):
        '''
        decoder_input_dim / decoder_state_dim: input and hidden size of
            the decoder LSTMCell.
        forget_bias, lstm_memory_optimization: passed to the LSTMCell.
        forward_only: passed to AttentionCell only; the decoder LSTMCell
            is always created with forward_only=False and
            drop_states=False.
        All other arguments are forwarded to AttentionCell unchanged.
        '''
        lstm_decoder = LSTMCell(
            name='{}/decoder'.format(name),
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            forward_only=False,
            drop_states=False,
        )
        super(LSTMWithAttentionCell, self).__init__(
            name=name,
            forward_only=forward_only,
            decoder_cell=lstm_decoder,
            decoder_state_dim=decoder_state_dim,
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=encoder_lengths,
            attention_type=attention_type,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=attention_memory_optimization,
        )
class MILSTMWithAttentionCell(AttentionCell):
    '''
    AttentionCell whose decoder is an MILSTMCell named "<name>/decoder".
    '''

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
        encoder_lengths=None,
    ):
        '''
        decoder_input_dim / decoder_state_dim: input and hidden size of
            the decoder MILSTMCell.
        forget_bias, lstm_memory_optimization: passed to the MILSTMCell.
        forward_only: passed to AttentionCell only; the decoder cell is
            always created with forward_only=False and drop_states=False.
        encoder_lengths: lengths of the encoder sequences, forwarded to
            AttentionCell. Defaults to None. It is a trailing keyword
            (unlike in LSTMWithAttentionCell) so that existing positional
            callers keep working.
        All other arguments are forwarded to AttentionCell unchanged.
        '''
        decoder_cell = MILSTMCell(
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            name='{}/decoder'.format(name),
            forward_only=False,
            drop_states=False,
        )
        # AttentionCell.__init__ requires encoder_lengths; omitting it
        # made every instantiation of this class fail with a TypeError.
        super(MILSTMWithAttentionCell, self).__init__(
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=encoder_lengths,
            decoder_cell=decoder_cell,
            decoder_state_dim=decoder_state_dim,
            name=name,
            attention_type=attention_type,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=attention_memory_optimization,
            forward_only=forward_only,
        )
def _LSTM(
    cell_class,
    model,
    input_blob,
    seq_lengths,
    initial_states,
    dim_in,
    dim_out,
    scope=None,
    outputs_with_grads=(0,),
    return_params=False,
    memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
    drop_states=False,
    return_last_layer_only=True,
    static_rnn_unroll_size=None,
    **cell_kwargs
):
    '''
    Adds a standard LSTM recurrent network operator to a model.

    cell_class: LSTMCell or compatible subclass

    model: ModelHelper object new operators would be added to

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    seq_lengths: blob containing sequence lengths which would be passed to
    LSTMUnit operator

    initial_states: a list of (2 * num_layers) blobs representing the initial
    hidden and cell states of each layer. If this argument is None,
    these states will be added to the model as network parameters.

    dim_in: input dimension

    dim_out: number of units per LSTM layer
    (use int for single-layer LSTM, list or tuple of ints for multi-layer)

    scope: name given to the resulting cell (the single cell, or the
    MultiRNNCell wrapping the layers)

    outputs_with_grads : position indices of output blobs for LAST LAYER which
    will receive external error gradient during backpropagation.
    These outputs are: (h_all, h_last, c_all, c_last)

    return_params: if True, will return a dictionary of parameters of the LSTM

    memory_optimization: if enabled, the LSTM step is recomputed on backward
    step so that we don't need to store forward activations for each
    timestep. Saves memory with cost of computation.

    forget_bias: forget gate bias (default 0.0)

    forward_only: whether to create only the forward pass (no backward
    pass)

    drop_states: drop invalid states, passed through to LSTMUnit operator

    return_last_layer_only: only return outputs from final layer
    (so that length of results does not depend on number of layers)

    static_rnn_unroll_size: if not None, we will use static RNN which is
    unrolled into Caffe2 graph. The size of the unroll is the value of
    this parameter.

    cell_kwargs: extra keyword arguments passed to every cell_class(...)
    call.

    Returns a tuple of output blobs; when return_params is True a dict
    {'input': ..., 'recurrent': ...} is appended as the last element.
    '''
    # isinstance (rather than an exact type comparison) also treats
    # subclasses of list/tuple as a per-layer specification.
    if not isinstance(dim_out, (list, tuple)):
        dim_out = [dim_out]
    num_layers = len(dim_out)

    cells = []
    for i in range(num_layers):
        cell = cell_class(
            # Layer 0 reads the network input; every other layer reads
            # the output of the layer below it.
            input_size=(dim_in if i == 0 else dim_out[i - 1]),
            hidden_size=dim_out[i],
            forget_bias=forget_bias,
            memory_optimization=memory_optimization,
            # With several layers the scope name goes to the enclosing
            # MultiRNNCell instead of the individual cells.
            name=scope if num_layers == 1 else None,
            forward_only=forward_only,
            drop_states=drop_states,
            **cell_kwargs
        )
        cells.append(cell)

    cell = MultiRNNCell(
        cells,
        name=scope,
        forward_only=forward_only,
    ) if num_layers > 1 else cells[0]

    cell = (
        cell if static_rnn_unroll_size is None
        else UnrolledCell(cell, static_rnn_unroll_size))

    # outputs_with_grads argument indexes into final layer.
    # NOTE(review): the stride of 4 assumes four output blobs per layer
    # (h_all, h_last, c_all, c_last); confirm for cell classes with a
    # different number of states.
    last_layer_offset = 4 * (num_layers - 1)
    outputs_with_grads = [last_layer_offset + i for i in outputs_with_grads]
    _, result = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        seq_lengths=seq_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )

    if return_last_layer_only:
        result = result[last_layer_offset:]
    if return_params:
        result = list(result) + [{
            'input': cell.get_input_params(),
            'recurrent': cell.get_recurrent_params(),
        }]
    return tuple(result)
# Public constructors: each one binds `_LSTM` to a concrete cell class, so
# callers supply the remaining `_LSTM` arguments (model, input_blob, ...).
LSTM = functools.partial(_LSTM, LSTMCell)
BasicRNN = functools.partial(_LSTM, BasicRNNCell)
MILSTM = functools.partial(_LSTM, MILSTMCell)
LayerNormLSTM = functools.partial(_LSTM, LayerNormLSTMCell)
LayerNormMILSTM = functools.partial(_LSTM, LayerNormMILSTMCell)
class UnrolledCell(RNNCell):
    '''
    Wraps a cell and applies it for a fixed number of timesteps T by
    adding every step's operators directly to the model's net.
    '''

    def __init__(self, cell, T):
        '''
        cell: the RNNCell to unroll.
        T: number of timesteps; the input sequence is split into exactly
           T pieces along axis 0.

        NOTE(review): RNNCell.__init__ is not called, so attributes it
        would set are absent on this wrapper; confirm this is intended.
        '''
        self.T = T
        self.cell = cell

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths,
        initial_states,
        outputs_with_grads=None,
    ):
        '''
        Apply the wrapped cell to all T timesteps of `inputs`.

        outputs_with_grads: indices into the returned `outputs` tuple
            that receive an external gradient; all other outputs get a
            ZeroGradient op. When None, the full-sequence blob of the
            wrapped cell's output state is used.

        Returns (final_output, outputs) where `outputs` interleaves, for
        every state, the concatenation over all timesteps with the value
        after the last timestep.
        '''
        if outputs_with_grads is None:
            # The declared default used to crash in set(None) below.
            # Outputs come in (sequence, last) pairs per state, hence the
            # factor 2. NOTE(review): confirm this matches the default of
            # RNNCell.apply_over_sequence.
            outputs_with_grads = [2 * self.cell.get_output_state_index()]

        inputs = self.cell.prepare_input(model, inputs)

        # Now they are blob references - outputs of splitting the input sequence
        split_inputs = model.net.Split(
            inputs,
            [str(inputs) + "_timestep_{}".format(i)
             for i in range(self.T)],
            axis=0)
        if self.T == 1:
            # A single requested output comes back bare, not as a list.
            split_inputs = [split_inputs]

        states = initial_states
        all_states = []
        for t in range(0, self.T):
            scope_name = "timestep_{}".format(t)
            # Parameters of all timesteps are shared
            with ParameterSharing({scope_name: ''}),\
                    scope.NameScope(scope_name):
                timestep = model.param_init_net.ConstantFill(
                    [], "timestep", value=t, shape=[1],
                    dtype=core.DataType.INT32,
                    device_option=core.DeviceOption(caffe2_pb2.CPU))
                states = self.cell._apply(
                    model=model,
                    input_t=split_inputs[t],
                    seq_lengths=seq_lengths,
                    states=states,
                    timestep=timestep,
                )
            all_states.append(states)

        # Transpose from per-timestep lists of states to per-state lists
        # of timesteps, then concatenate each along the time axis. The
        # concatenated blob is named after the timestep-0 blob with its
        # "timestep_0/" prefix removed.
        all_states = zip(*all_states)
        all_states = [
            model.net.Concat(
                list(full_output),
                [
                    str(full_output[0])[len("timestep_0/"):] + "_concat",
                    str(full_output[0])[len("timestep_0/"):] + "_concat_info"
                ],
                axis=0)[0]
            for full_output in all_states
        ]

        # Interleave the state values similar to
        #
        #   x = [1, 3, 5]
        #   y = [2, 4, 6]
        #   z = [val for pair in zip(x, y) for val in pair]
        #   # z is [1, 2, 3, 4, 5, 6]
        #
        # and returns it as outputs
        outputs = tuple(
            state for state_pair in zip(all_states, states) for state in state_pair
        )
        outputs_without_grad = set(range(len(outputs))) - set(
            outputs_with_grads)
        for i in outputs_without_grad:
            model.net.ZeroGradient(outputs[i], [])
        # The message needs a %s placeholder: logging formats lazily with
        # `msg % args`, so an argument without a placeholder produces a
        # formatting error when debug logging is enabled.
        logging.debug("Added 0 gradients for blobs: %s",
                      [outputs[i] for i in outputs_without_grad])

        final_output = self.cell._prepare_output_sequence(model, outputs)

        return final_output, outputs
def GetLSTMParamNames():
    '''
    Return the names of the LSTM gate parameters as a dict with two
    lists, 'weights' and 'biases'. A fresh dict is built on every call.
    '''
    return {
        'weights': [
            "input_gate_w", "forget_gate_w", "output_gate_w", "cell_w",
        ],
        'biases': [
            "input_gate_b", "forget_gate_b", "output_gate_b", "cell_b",
        ],
    }
def InitFromLSTMParams(lstm_pblobs, param_values):
    '''
    Set the parameters of LSTM based on predefined values

    lstm_pblobs: dict input_type -> {'weights': blob, 'biases': blob}
        naming the workspace blobs to overwrite.
    param_values: dict input_type -> {param name: array}, with one array
        for every name listed by GetLSTMParamNames().

    For every input type the gate arrays are flattened and packed in the
    order given by GetLSTMParamNames(), reshaped to the shape the target
    blob currently has in the workspace, and fed back as float32.
    '''
    weight_params = GetLSTMParamNames()['weights']
    bias_params = GetLSTMParamNames()['biases']

    def pack(values_by_name, ordered_names):
        # Flatten every array first, then append them one after another.
        pieces = [values_by_name[pname].flatten() for pname in ordered_names]
        packed = np.array([])
        for piece in pieces:
            packed = np.append(packed, piece)
        return packed

    for input_type in viewkeys(param_values):
        wmat = pack(param_values[input_type], weight_params)
        bm = pack(param_values[input_type], bias_params)

        weights_blob = lstm_pblobs[input_type]['weights']
        bias_blob = lstm_pblobs[input_type]['biases']
        # The blobs' current contents are only fetched for their shapes.
        weight_shape = workspace.FetchBlob(weights_blob).shape
        bias_shape = workspace.FetchBlob(bias_blob).shape

        workspace.FeedBlob(
            weights_blob,
            wmat.reshape(weight_shape).astype(np.float32))
        workspace.FeedBlob(
            bias_blob,
            bm.reshape(bias_shape).astype(np.float32))
def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
               scope, recurrent_params=None, input_params=None,
               num_layers=1, return_params=False):
    '''
    CuDNN version of LSTM for GPUs.

    input_blob          Blob containing the input. Will need to be available
                        when param_init_net is run, because the sequence lengths
                        and batch sizes will be inferred from the size of this
                        blob.
    initial_states      tuple of (hidden_init, cell_init) blobs
    dim_in              input dimensions
    dim_out             output/hidden dimension
    scope               namescope to apply
    recurrent_params    dict of blobs containing values for recurrent
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    input_params        dict of blobs containing values for input
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    num_layers          number of LSTM layers
    return_params       if True, returns (param_extract_net, param_mapping)
                        where param_extract_net is a net that when run, will
                        populate the blobs specified in param_mapping with the
                        current gate weights and biases (input/recurrent).
                        Useful for assigning the values back to non-cuDNN
                        LSTM.

    Returns (output, hidden_output, cell_output), with the
    (param_extract_net, param_mapping) pair appended when return_params
    is True.
    '''
    with core.NameScope(scope):
        weight_params = GetLSTMParamNames()['weights']
        bias_params = GetLSTMParamNames()['biases']

        # Number of elements of one gate's parameter, by kind.
        input_weight_size = dim_out * dim_in
        upper_layer_input_weight_size = dim_out * dim_out
        recurrent_weight_size = dim_out * dim_out
        input_bias_size = dim_out
        recurrent_bias_size = dim_out

        def init(layer, pname, input_type):
            # Randomly filled stand-in for a parameter the caller did
            # not supply.
            for_input = input_type == 'input'
            if pname in weight_params:
                if not for_input:
                    sz = recurrent_weight_size
                elif layer == 0:
                    sz = input_weight_size
                else:
                    sz = upper_layer_input_weight_size
            elif pname in bias_params:
                sz = input_bias_size if for_input else recurrent_bias_size
            else:
                assert False, "unknown parameter type {}".format(pname)
            return model.param_init_net.UniformFill(
                [],
                "lstm_init_{}_{}_{}".format(input_type, pname, layer),
                shape=[sz])

        # Multiply by 4 since we have 4 gates per LSTM unit
        shared_sz = recurrent_weight_size + input_bias_size + \
            recurrent_bias_size
        first_layer_sz = input_weight_size + shared_sz
        upper_layer_sz = upper_layer_input_weight_size + shared_sz
        total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)

        weights = model.create_param(
            'lstm_weight',
            shape=[total_sz],
            initializer=Initializer('UniformFill'),
            tags=ParameterTags.WEIGHT,
        )

        lstm_args = {
            'hidden_size': dim_out,
            'rnn_mode': 'lstm',
            'bidirectional': 0,  # TODO
            'dropout': 1.0,  # TODO
            'input_mode': 'linear',  # TODO
            'num_layers': num_layers,
            'engine': 'CUDNN'
        }

        param_extract_net = core.Net("lstm_param_extractor")
        param_extract_net.AddExternalInputs([input_blob, weights])
        param_extract_mapping = {}

        # Populate the weights-blob from blobs containing parameters for
        # the individual components of the LSTM, such as forget/input gate
        # weights and biases. Also, create a special param_extract_net that
        # can be used to grab those individual params from the black-box
        # weights blob. These results can be then fed to InitFromLSTMParams()
        for input_type in ['input', 'recurrent']:
            extracted = param_extract_mapping[input_type] = {}
            if input_type == 'recurrent':
                supplied = recurrent_params
            else:
                supplied = input_params
            if supplied is None:
                supplied = {}

            for pname in weight_params + bias_params:
                for layer in range(num_layers):
                    if pname in supplied:
                        values = supplied[pname]
                    else:
                        values = init(layer, pname, input_type)
                    model.param_init_net.RecurrentParamSet(
                        [input_blob, weights, values],
                        weights,
                        layer=layer,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )

                    per_layer = extracted.setdefault(pname, {})
                    per_layer[layer] = param_extract_net.RecurrentParamGet(
                        [input_blob, weights],
                        ["lstm_{}_{}_{}".format(input_type, pname, layer)],
                        layer=layer,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )

        (hidden_input_blob, cell_input_blob) = initial_states
        recurrent_outputs = model.net.Recurrent(
            [input_blob, hidden_input_blob, cell_input_blob, weights],
            ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
             "lstm_rnn_scratch", "lstm_dropout_states"],
            seed=random.randint(0, 100000),  # TODO: dropout seed
            **lstm_args
        )
        (output, hidden_output, cell_output,
         rnn_scratch, dropout_states) = recurrent_outputs
        model.net.AddExternalOutputs(
            hidden_output, cell_output, rnn_scratch, dropout_states)

    if not return_params:
        return output, hidden_output, cell_output
    param_extract = param_extract_net, param_extract_mapping
    return output, hidden_output, cell_output, param_extract
def LSTMWithAttention(
    model,
    decoder_inputs,
    decoder_input_lengths,
    initial_decoder_hidden_state,
    initial_decoder_cell_state,
    initial_attention_weighted_encoder_context,
    encoder_output_dim,
    encoder_outputs,
    encoder_lengths,
    decoder_input_dim,
    decoder_state_dim,
    scope,
    attention_type=AttentionType.Regular,
    outputs_with_grads=(0, 4),
    weighted_encoder_outputs=None,
    lstm_memory_optimization=False,
    attention_memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
):
    '''
    Adds a LSTM with attention mechanism to a model.

    The implementation is based on https://arxiv.org/abs/1409.0473, with
    a small difference in the order
    how we compute new attention context and new hidden state, similarly to
    https://arxiv.org/abs/1508.04025.

    The model uses encoder-decoder naming conventions,
    where the decoder is the sequence the op is iterating over,
    while computing the attention context over the encoder.

    model: ModelHelper object new operators would be added to

    decoder_inputs: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    decoder_input_lengths: blob containing sequence lengths
    which would be passed to LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of LSTM

    initial_decoder_cell_state: initial cell state of LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of encoder outputs

    encoder_outputs: the sequence, on which we compute the attention context
    at every iteration

    encoder_lengths: a tensor with lengths of each encoder sequence in batch
    (may be None, meaning all encoder sequences are of same length)

    decoder_input_dim: input dimension (last dimension on decoder_inputs)

    decoder_state_dim: size of hidden states of LSTM

    scope: name given to the attention cell

    attention_type: One of: AttentionType.Regular, AttentionType.Recurrent,
    AttentionType.Dot, AttentionType.SoftCoverage.
    Determines which type of attention mechanism to use. For SoftCoverage
    an initial coverage state is built and appended to the initial states.

    outputs_with_grads : position indices of output blobs which will receive
    external error gradient during backpropagation

    weighted_encoder_outputs: encoder outputs to be used to compute attention
    weights. In the basic case it's just linear transformation of
    encoder outputs (that's the default, when weighted_encoder_outputs is
    None). However, it can be something more complicated - like a separate
    encoder network (for example, in case of convolutional encoder)

    lstm_memory_optimization: recompute LSTM activations on backward pass, so
    we don't need to store their values in forward passes

    attention_memory_optimization: recompute attention for backward pass

    forget_bias: forget gate bias of the decoder LSTM

    forward_only: whether to create only forward pass

    Returns the second element of the pair produced by the cell's
    apply_over_sequence().
    '''
    attention_cell = LSTMWithAttentionCell(
        name=scope,
        encoder_output_dim=encoder_output_dim,
        encoder_outputs=encoder_outputs,
        encoder_lengths=encoder_lengths,
        decoder_input_dim=decoder_input_dim,
        decoder_state_dim=decoder_state_dim,
        attention_type=attention_type,
        weighted_encoder_outputs=weighted_encoder_outputs,
        forget_bias=forget_bias,
        lstm_memory_optimization=lstm_memory_optimization,
        attention_memory_optimization=attention_memory_optimization,
        forward_only=forward_only,
    )

    # Same order as the cell's states: hidden, cell, attention context
    # and, for soft coverage only, the coverage state.
    initial_states = [
        initial_decoder_hidden_state,
        initial_decoder_cell_state,
        initial_attention_weighted_encoder_context,
    ]
    if attention_type == AttentionType.SoftCoverage:
        initial_coverage = attention_cell.build_initial_coverage(model)
        initial_states.append(initial_coverage)

    _unused_output, state_outputs = attention_cell.apply_over_sequence(
        model=model,
        inputs=decoder_inputs,
        seq_lengths=decoder_input_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )
    return state_outputs
def _layered_LSTM(
        model, input_blob, seq_lengths, initial_states,
        dim_in, dim_out, scope, outputs_with_grads=(0,), return_params=False,
        memory_optimization=False, forget_bias=0.0, forward_only=False,
        drop_states=False, create_lstm=None):
    '''
    Build a stack of LSTMs by calling `create_lstm` once per layer.

    dim_out: an int, or a list with one entry per layer. A non-list value
        or a one-element list results in a single create_lstm call whose
        result is returned as is.
    create_lstm: callable accepting all other arguments of this function
        as keyword arguments; for the layered case it must return
        (output, last_output, all_states, last_state).

    In the layered case every layer after the first receives the previous
    layer's output as input_blob, the previous layer's size as dim_in,
    the previous layer's (last_output, last_state) as initial_states, and
    scope + '_layer_<n>' as its scope. The results of the last layer are
    returned. return_params is not supported for layering.
    '''
    # Must stay the first statement so that exactly the call arguments
    # are captured. The dict is copied because on CPython before 3.13
    # locals() hands out the frame's own dict, which is refreshed when a
    # trace function (debugger, coverage) is active; that would restore
    # 'create_lstm' and revert the per-layer edits made below.
    params = dict(locals())
    params.pop('create_lstm')
    if not isinstance(dim_out, list):
        return create_lstm(**params)
    elif len(dim_out) == 1:
        params['dim_out'] = dim_out[0]
        return create_lstm(**params)

    assert len(dim_out) != 0, "dim_out list can't be empty"
    assert return_params is False, "return_params not supported for layering"
    for i, output_dim in enumerate(dim_out):
        params.update({
            'dim_out': output_dim
        })
        output, last_output, all_states, last_state = create_lstm(**params)
        # Wire this layer's results into the next layer's arguments.
        params.update({
            'input_blob': output,
            'dim_in': output_dim,
            'initial_states': (last_output, last_state),
            'scope': scope + '_layer_{}'.format(i + 1)
        })
    return output, last_output, all_states, last_state
# Layered variant that builds every layer with the `LSTM` constructor.
layered_LSTM = functools.partial(_layered_LSTM, create_lstm=LSTM)
|