embedding_generation_benchmark.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. ## @package embedding_generation_benchmark
  2. # Module caffe2.python.embedding_generation_benchmark
  3. from caffe2.proto import caffe2_pb2
  4. from caffe2.python import workspace, core, utils, model_helper
  5. import argparse
  6. import numpy as np
  7. import time
  8. import logging
  9. logging.basicConfig()
  10. log = logging.getLogger("embedding_generation_benchmark")
  11. log.setLevel(logging.DEBUG)
  12. def generate_data(T, batch_size, max_seq_length):
  13. '''
  14. Fill a queue with input data
  15. '''
  16. log.info("Generating T={} batches".format(T))
  17. generate_input_init_net = core.Net('generate_input_init')
  18. queue = generate_input_init_net.CreateBlobsQueue(
  19. [], "inputqueue", num_blobs=1, capacity=T,
  20. )
  21. workspace.RunNetOnce(generate_input_init_net)
  22. generate_input_net = core.Net('generate_input')
  23. generate_input_net.EnqueueBlobs([queue, "scratch"], ["scratch"])
  24. np.random.seed(2603)
  25. for t in range(T):
  26. if (t % (max(10, T // 10)) == 0):
  27. log.info("Generating data {}/{}".format(t, T))
  28. X = np.tile(np.arange(max_seq_length), [batch_size, 1]).transpose()
  29. workspace.FeedBlob("scratch", X)
  30. workspace.RunNetOnce(generate_input_net.Proto())
  31. log.info("Finished data generation")
  32. return queue
  33. def generate_embedding_table(vocab_size, embedding_size):
  34. log.info("Generating embedding table with dimensions {}"
  35. .format([vocab_size, embedding_size]))
  36. generate_table_net = core.Net('generate_table')
  37. table = generate_table_net.GaussianFill(
  38. [],
  39. ['embedding_table'],
  40. shape=[vocab_size, embedding_size],
  41. )
  42. workspace.RunNetOnce(generate_table_net)
  43. return table
  44. def create_model(args, queue, embedding_table, embedding_size):
  45. model = model_helper.ModelHelper(name='embedding_generation_bench')
  46. input_blob = model.net.DequeueBlobs(queue, 'input_data')
  47. if args.implementation == 'sinusoid':
  48. model.net.SinusoidPositionEncoding(
  49. [input_blob],
  50. ['output'],
  51. embedding_size=embedding_size
  52. )
  53. else:
  54. model.net.Gather(
  55. [embedding_table, input_blob],
  56. ['output'],
  57. )
  58. return model
  59. def Caffe2EmbeddingGeneration(args):
  60. T = args.data_size // args.batch_size
  61. queue = generate_data(T, args.batch_size, args.seq_length)
  62. embedding_table = None
  63. if args.implementation == 'table':
  64. embedding_table = generate_embedding_table(
  65. args.seq_length,
  66. args.embedding_size,
  67. )
  68. model = create_model(args, queue, embedding_table, args.embedding_size)
  69. workspace.RunNetOnce(model.param_init_net)
  70. workspace.CreateNet(model.net)
  71. start_time = time.time()
  72. num_iters = T
  73. total_iters = 0
  74. # Run the Benchmark
  75. log.info("------ Warming up ------")
  76. workspace.RunNet(model.net.Proto().name)
  77. log.info("------ Starting benchmark ------")
  78. start_time = time.time()
  79. last_time = time.time()
  80. for iteration in range(1, num_iters, args.iters_to_report):
  81. iters_once = min(args.iters_to_report, num_iters - iteration)
  82. total_iters += iters_once
  83. workspace.RunNet(model.net.Proto().name, iters_once)
  84. new_time = time.time()
  85. log.info(
  86. "Iter: {} / {}. Embeddings Generated Per Second: {}k.".format(
  87. iteration,
  88. num_iters,
  89. (iters_once * args.batch_size * args.seq_length) /
  90. (new_time - last_time) // 100 / 10,
  91. )
  92. )
  93. last_time = new_time
  94. total_per_sec = (num_iters - 1) * args.batch_size * args.seq_length
  95. total_per_sec = total_per_sec / (time.time() - start_time) // 100 / 10
  96. log.info("Done. Total embeddings generated per second " +
  97. "excluding 1st iteration: {}k".format(total_per_sec))
  98. return time.time() - start_time
  99. @utils.debug
  100. def Benchmark(args):
  101. return Caffe2EmbeddingGeneration(args)
  102. def GetArgumentParser():
  103. parser = argparse.ArgumentParser(
  104. description="Embedding generation benchmark."
  105. )
  106. parser.add_argument(
  107. "--embedding_size",
  108. type=int,
  109. default=512,
  110. help="Embedding size",
  111. )
  112. parser.add_argument(
  113. "--batch_size",
  114. type=int,
  115. default=16,
  116. help="The batch size."
  117. )
  118. parser.add_argument(
  119. "--data_size",
  120. type=int,
  121. default=10000,
  122. help="Number of sequences to generate"
  123. )
  124. parser.add_argument(
  125. "--seq_length",
  126. type=int,
  127. default=128,
  128. help="Max sequence length"
  129. )
  130. parser.add_argument(
  131. "--iters_to_report",
  132. type=int,
  133. default=20,
  134. help="Number of iterations to report progress"
  135. )
  136. parser.add_argument(
  137. "--implementation",
  138. type=str,
  139. default="sinusoid",
  140. help="'table' or 'sinusoid'",
  141. )
  142. return parser
  143. if __name__ == '__main__':
  144. args, extra_args = GetArgumentParser().parse_known_args()
  145. workspace.GlobalInit([
  146. 'caffe2',
  147. '--caffe2_log_level=0',
  148. '--caffe2_print_blob_sizes_at_exit=0'] + extra_args)
  149. device = core.DeviceOption(caffe2_pb2.CPU)
  150. with core.DeviceScope(device):
  151. Benchmark(args)