char_rnn.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. ## @package char_rnn
  2. # Module caffe2.python.examples.char_rnn
  3. from caffe2.python import core, workspace, model_helper, utils, brew
  4. from caffe2.python.rnn_cell import LSTM
  5. from caffe2.proto import caffe2_pb2
  6. from caffe2.python.optimizer import build_sgd
  7. import argparse
  8. import logging
  9. import numpy as np
  10. from datetime import datetime
  11. '''
  12. This script takes a text file as input and uses a recurrent neural network
  13. to learn to predict next character in a sequence.
  14. '''
# Configure the root logger with the standard library's default settings.
logging.basicConfig()
# Logger shared by every function and method in this script.
log = logging.getLogger("char_rnn")
# Let DEBUG-level messages (e.g. the per-report loss values) through.
log.setLevel(logging.DEBUG)
  18. # Default set() here is intentional as it would accumulate values like a global
  19. # variable
  20. def CreateNetOnce(net, created_names=set()): # noqa
  21. name = net.Name()
  22. if name not in created_names:
  23. created_names.add(name)
  24. workspace.CreateNet(net)
  25. class CharRNN(object):
  26. def __init__(self, args):
  27. self.seq_length = args.seq_length
  28. self.batch_size = args.batch_size
  29. self.iters_to_report = args.iters_to_report
  30. self.hidden_size = args.hidden_size
  31. with open(args.train_data) as f:
  32. self.text = f.read()
  33. self.vocab = list(set(self.text))
  34. self.char_to_idx = {ch: idx for idx, ch in enumerate(self.vocab)}
  35. self.idx_to_char = {idx: ch for idx, ch in enumerate(self.vocab)}
  36. self.D = len(self.char_to_idx)
  37. print("Input has {} characters. Total input size: {}".format(
  38. len(self.vocab), len(self.text)))
  39. def CreateModel(self):
  40. log.debug("Start training")
  41. model = model_helper.ModelHelper(name="char_rnn")
  42. input_blob, seq_lengths, hidden_init, cell_init, target = \
  43. model.net.AddExternalInputs(
  44. 'input_blob',
  45. 'seq_lengths',
  46. 'hidden_init',
  47. 'cell_init',
  48. 'target',
  49. )
  50. hidden_output_all, self.hidden_output, _, self.cell_state = LSTM(
  51. model, input_blob, seq_lengths, (hidden_init, cell_init),
  52. self.D, self.hidden_size, scope="LSTM")
  53. output = brew.fc(
  54. model,
  55. hidden_output_all,
  56. None,
  57. dim_in=self.hidden_size,
  58. dim_out=self.D,
  59. axis=2
  60. )
  61. # axis is 2 as first two are T (time) and N (batch size).
  62. # We treat them as one big batch of size T * N
  63. softmax = model.net.Softmax(output, 'softmax', axis=2)
  64. softmax_reshaped, _ = model.net.Reshape(
  65. softmax, ['softmax_reshaped', '_'], shape=[-1, self.D])
  66. # Create a copy of the current net. We will use it on the forward
  67. # pass where we don't need loss and backward operators
  68. self.forward_net = core.Net(model.net.Proto())
  69. xent = model.net.LabelCrossEntropy([softmax_reshaped, target], 'xent')
  70. # Loss is average both across batch and through time
  71. # Thats why the learning rate below is multiplied by self.seq_length
  72. loss = model.net.AveragedLoss(xent, 'loss')
  73. model.AddGradientOperators([loss])
  74. # use build_sdg function to build an optimizer
  75. build_sgd(
  76. model,
  77. base_learning_rate=0.1 * self.seq_length,
  78. policy="step",
  79. stepsize=1,
  80. gamma=0.9999
  81. )
  82. self.model = model
  83. self.predictions = softmax
  84. self.loss = loss
  85. self.prepare_state = core.Net("prepare_state")
  86. self.prepare_state.Copy(self.hidden_output, hidden_init)
  87. self.prepare_state.Copy(self.cell_state, cell_init)
  88. def _idx_at_pos(self, pos):
  89. return self.char_to_idx[self.text[pos]]
  90. def TrainModel(self):
  91. log.debug("Training model")
  92. workspace.RunNetOnce(self.model.param_init_net)
  93. # As though we predict the same probability for each character
  94. smooth_loss = -np.log(1.0 / self.D) * self.seq_length
  95. last_n_iter = 0
  96. last_n_loss = 0.0
  97. num_iter = 0
  98. N = len(self.text)
  99. # We split text into batch_size pieces. Each piece will be used only
  100. # by a corresponding batch during the training process
  101. text_block_positions = np.zeros(self.batch_size, dtype=np.int32)
  102. text_block_size = N // self.batch_size
  103. text_block_starts = list(range(0, N, text_block_size))
  104. text_block_sizes = [text_block_size] * self.batch_size
  105. text_block_sizes[self.batch_size - 1] += N % self.batch_size
  106. assert sum(text_block_sizes) == N
  107. # Writing to output states which will be copied to input
  108. # states within the loop below
  109. workspace.FeedBlob(self.hidden_output, np.zeros(
  110. [1, self.batch_size, self.hidden_size], dtype=np.float32
  111. ))
  112. workspace.FeedBlob(self.cell_state, np.zeros(
  113. [1, self.batch_size, self.hidden_size], dtype=np.float32
  114. ))
  115. workspace.CreateNet(self.prepare_state)
  116. # We iterate over text in a loop many times. Each time we peak
  117. # seq_length segment and feed it to LSTM as a sequence
  118. last_time = datetime.now()
  119. progress = 0
  120. while True:
  121. workspace.FeedBlob(
  122. "seq_lengths",
  123. np.array([self.seq_length] * self.batch_size,
  124. dtype=np.int32)
  125. )
  126. workspace.RunNet(self.prepare_state.Name())
  127. input = np.zeros(
  128. [self.seq_length, self.batch_size, self.D]
  129. ).astype(np.float32)
  130. target = np.zeros(
  131. [self.seq_length * self.batch_size]
  132. ).astype(np.int32)
  133. for e in range(self.batch_size):
  134. for i in range(self.seq_length):
  135. pos = text_block_starts[e] + text_block_positions[e]
  136. input[i][e][self._idx_at_pos(pos)] = 1
  137. target[i * self.batch_size + e] =\
  138. self._idx_at_pos((pos + 1) % N)
  139. text_block_positions[e] = (
  140. text_block_positions[e] + 1) % text_block_sizes[e]
  141. progress += 1
  142. workspace.FeedBlob('input_blob', input)
  143. workspace.FeedBlob('target', target)
  144. CreateNetOnce(self.model.net)
  145. workspace.RunNet(self.model.net.Name())
  146. num_iter += 1
  147. last_n_iter += 1
  148. if num_iter % self.iters_to_report == 0:
  149. new_time = datetime.now()
  150. print("Characters Per Second: {}". format(
  151. int(progress / (new_time - last_time).total_seconds())
  152. ))
  153. print("Iterations Per Second: {}". format(
  154. int(self.iters_to_report /
  155. (new_time - last_time).total_seconds())
  156. ))
  157. last_time = new_time
  158. progress = 0
  159. print("{} Iteration {} {}".
  160. format('-' * 10, num_iter, '-' * 10))
  161. loss = workspace.FetchBlob(self.loss) * self.seq_length
  162. smooth_loss = 0.999 * smooth_loss + 0.001 * loss
  163. last_n_loss += loss
  164. if num_iter % self.iters_to_report == 0:
  165. self.GenerateText(500, np.random.choice(self.vocab))
  166. log.debug("Loss since last report: {}"
  167. .format(last_n_loss / last_n_iter))
  168. log.debug("Smooth loss: {}".format(smooth_loss))
  169. last_n_loss = 0.0
  170. last_n_iter = 0
  171. def GenerateText(self, num_characters, ch):
  172. # Given a starting symbol we feed a fake sequence of size 1 to
  173. # our RNN num_character times. After each time we use output
  174. # probabilities to pick a next character to feed to the network.
  175. # Same character becomes part of the output
  176. CreateNetOnce(self.forward_net)
  177. text = '' + ch
  178. for _i in range(num_characters):
  179. workspace.FeedBlob(
  180. "seq_lengths", np.array([1] * self.batch_size, dtype=np.int32))
  181. workspace.RunNet(self.prepare_state.Name())
  182. input = np.zeros([1, self.batch_size, self.D]).astype(np.float32)
  183. input[0][0][self.char_to_idx[ch]] = 1
  184. workspace.FeedBlob("input_blob", input)
  185. workspace.RunNet(self.forward_net.Name())
  186. p = workspace.FetchBlob(self.predictions)
  187. next = np.random.choice(self.D, p=p[0][0])
  188. ch = self.idx_to_char[next]
  189. text += ch
  190. print(text)
  191. @utils.debug
  192. def main():
  193. parser = argparse.ArgumentParser(
  194. description="Caffe2: Char RNN Training"
  195. )
  196. parser.add_argument("--train_data", type=str, default=None,
  197. help="Path to training data in a text file format",
  198. required=True)
  199. parser.add_argument("--seq_length", type=int, default=25,
  200. help="One training example sequence length")
  201. parser.add_argument("--batch_size", type=int, default=1,
  202. help="Training batch size")
  203. parser.add_argument("--iters_to_report", type=int, default=500,
  204. help="How often to report loss and generate text")
  205. parser.add_argument("--hidden_size", type=int, default=100,
  206. help="Dimension of the hidden representation")
  207. parser.add_argument("--gpu", action="store_true",
  208. help="If set, training is going to use GPU 0")
  209. args = parser.parse_args()
  210. device = core.DeviceOption(
  211. workspace.GpuDeviceType if args.gpu else caffe2_pb2.CPU, 0)
  212. with core.DeviceScope(device):
  213. model = CharRNN(args)
  214. model.CreateModel()
  215. model.TrainModel()
# Script entry point: initialise Caffe2 with its log-level flag, then hand
# control to main().
if __name__ == '__main__':
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=2'])
    main()