gru_cell.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import functools
  2. from caffe2.python import brew, rnn_cell
  3. class GRUCell(rnn_cell.RNNCell):
  4. def __init__(
  5. self,
  6. input_size,
  7. hidden_size,
  8. forget_bias, # Currently unused! Values here will be ignored.
  9. memory_optimization,
  10. drop_states=False,
  11. linear_before_reset=False,
  12. **kwargs
  13. ):
  14. super(GRUCell, self).__init__(**kwargs)
  15. self.input_size = input_size
  16. self.hidden_size = hidden_size
  17. self.forget_bias = float(forget_bias)
  18. self.memory_optimization = memory_optimization
  19. self.drop_states = drop_states
  20. self.linear_before_reset = linear_before_reset
  21. # Unlike LSTMCell, GRUCell needs the output of one gate to feed into another.
  22. # (reset gate -> output_gate)
  23. # So, much of the logic to calculate the reset gate output and modified
  24. # output gate input is set here, in the graph definition.
  25. # The remaining logic lives in gru_unit_op.{h,cc}.
  26. def _apply(
  27. self,
  28. model,
  29. input_t,
  30. seq_lengths,
  31. states,
  32. timestep,
  33. extra_inputs=None,
  34. ):
  35. hidden_t_prev = states[0]
  36. # Split input tensors to get inputs for each gate.
  37. input_t_reset, input_t_update, input_t_output = model.net.Split(
  38. [
  39. input_t,
  40. ],
  41. [
  42. self.scope('input_t_reset'),
  43. self.scope('input_t_update'),
  44. self.scope('input_t_output'),
  45. ],
  46. axis=2,
  47. )
  48. # Fully connected layers for reset and update gates.
  49. reset_gate_t = brew.fc(
  50. model,
  51. hidden_t_prev,
  52. self.scope('reset_gate_t'),
  53. dim_in=self.hidden_size,
  54. dim_out=self.hidden_size,
  55. axis=2,
  56. )
  57. update_gate_t = brew.fc(
  58. model,
  59. hidden_t_prev,
  60. self.scope('update_gate_t'),
  61. dim_in=self.hidden_size,
  62. dim_out=self.hidden_size,
  63. axis=2,
  64. )
  65. # Calculating the modified hidden state going into output gate.
  66. reset_gate_t = model.net.Sum(
  67. [reset_gate_t, input_t_reset],
  68. self.scope('reset_gate_t')
  69. )
  70. reset_gate_t_sigmoid = model.net.Sigmoid(
  71. reset_gate_t,
  72. self.scope('reset_gate_t_sigmoid')
  73. )
  74. # `self.linear_before_reset = True` matches cudnn semantics
  75. if self.linear_before_reset:
  76. output_gate_fc = brew.fc(
  77. model,
  78. hidden_t_prev,
  79. self.scope('output_gate_t'),
  80. dim_in=self.hidden_size,
  81. dim_out=self.hidden_size,
  82. axis=2,
  83. )
  84. output_gate_t = model.net.Mul(
  85. [reset_gate_t_sigmoid, output_gate_fc],
  86. self.scope('output_gate_t_mul')
  87. )
  88. else:
  89. modified_hidden_t_prev = model.net.Mul(
  90. [reset_gate_t_sigmoid, hidden_t_prev],
  91. self.scope('modified_hidden_t_prev')
  92. )
  93. output_gate_t = brew.fc(
  94. model,
  95. modified_hidden_t_prev,
  96. self.scope('output_gate_t'),
  97. dim_in=self.hidden_size,
  98. dim_out=self.hidden_size,
  99. axis=2,
  100. )
  101. # Add input contributions to update and output gate.
  102. # We already (in-place) added input contributions to the reset gate.
  103. update_gate_t = model.net.Sum(
  104. [update_gate_t, input_t_update],
  105. self.scope('update_gate_t'),
  106. )
  107. output_gate_t = model.net.Sum(
  108. [output_gate_t, input_t_output],
  109. self.scope('output_gate_t_summed'),
  110. )
  111. # Join gate outputs and add input contributions
  112. gates_t, _gates_t_concat_dims = model.net.Concat(
  113. [
  114. reset_gate_t,
  115. update_gate_t,
  116. output_gate_t,
  117. ],
  118. [
  119. self.scope('gates_t'),
  120. self.scope('_gates_t_concat_dims'),
  121. ],
  122. axis=2,
  123. )
  124. if seq_lengths is not None:
  125. inputs = [hidden_t_prev, gates_t, seq_lengths, timestep]
  126. else:
  127. inputs = [hidden_t_prev, gates_t, timestep]
  128. hidden_t = model.net.GRUUnit(
  129. inputs,
  130. list(self.get_state_names()),
  131. forget_bias=self.forget_bias,
  132. drop_states=self.drop_states,
  133. sequence_lengths=(seq_lengths is not None),
  134. )
  135. model.net.AddExternalOutputs(hidden_t)
  136. return (hidden_t,)
  137. def prepare_input(self, model, input_blob):
  138. return brew.fc(
  139. model,
  140. input_blob,
  141. self.scope('i2h'),
  142. dim_in=self.input_size,
  143. dim_out=3 * self.hidden_size,
  144. axis=2,
  145. )
  146. def get_state_names(self):
  147. return (self.scope('hidden_t'),)
  148. def get_output_dim(self):
  149. return self.hidden_size
# Public GRU builder: rnn_cell._LSTM with GRUCell bound as its first
# positional argument. Presumably that argument is the cell class and _LSTM
# is the generic sequence driver despite its LSTM-specific name -- confirm
# against rnn_cell._LSTM's signature. Callers pass the remaining _LSTM args.
GRU = functools.partial(rnn_cell._LSTM, GRUCell)