# numa_benchmark.py: Caffe2 benchmark comparing single-socket vs cross-socket
# (NUMA) copy bandwidth.

  1. from caffe2.python import core, workspace
  2. from caffe2.proto import caffe2_pb2
  3. import time
  4. SHAPE_LEN = 4096
  5. NUM_ITER = 1000
  6. GB = 1024 * 1024 * 1024
  7. NUM_REPLICAS = 48
  8. def build_net(net_name, cross_socket):
  9. init_net = core.Net(net_name + "_init")
  10. init_net.Proto().type = "async_scheduling"
  11. numa_device_option = caffe2_pb2.DeviceOption()
  12. numa_device_option.device_type = caffe2_pb2.CPU
  13. numa_device_option.numa_node_id = 0
  14. for replica_id in range(NUM_REPLICAS):
  15. init_net.XavierFill([], net_name + "/input_blob_" + str(replica_id),
  16. shape=[SHAPE_LEN, SHAPE_LEN], device_option=numa_device_option)
  17. net = core.Net(net_name)
  18. net.Proto().type = "async_scheduling"
  19. if cross_socket:
  20. numa_device_option.numa_node_id = 1
  21. for replica_id in range(NUM_REPLICAS):
  22. net.Copy(net_name + "/input_blob_" + str(replica_id),
  23. net_name + "/output_blob_" + str(replica_id),
  24. device_option=numa_device_option)
  25. return init_net, net
  26. def main():
  27. assert workspace.IsNUMAEnabled() and workspace.GetNumNUMANodes() >= 2
  28. single_init, single_net = build_net("single_net", False)
  29. cross_init, cross_net = build_net("cross_net", True)
  30. workspace.CreateNet(single_init)
  31. workspace.RunNet(single_init.Name())
  32. workspace.CreateNet(cross_init)
  33. workspace.RunNet(cross_init.Name())
  34. workspace.CreateNet(single_net)
  35. workspace.CreateNet(cross_net)
  36. for _ in range(4):
  37. t = time.time()
  38. workspace.RunNet(single_net.Name(), NUM_ITER)
  39. dt = time.time() - t
  40. print("Single socket time:", dt)
  41. single_bw = 4 * SHAPE_LEN * SHAPE_LEN * NUM_REPLICAS * NUM_ITER / dt / GB
  42. print("Single socket BW: {} GB/s".format(single_bw))
  43. t = time.time()
  44. workspace.RunNet(cross_net.Name(), NUM_ITER)
  45. dt = time.time() - t
  46. print("Cross socket time:", dt)
  47. cross_bw = 4 * SHAPE_LEN * SHAPE_LEN * NUM_REPLICAS * NUM_ITER / dt / GB
  48. print("Cross socket BW: {} GB/s".format(cross_bw))
  49. print("Single BW / Cross BW: {}".format(single_bw / cross_bw))
  50. if __name__ == '__main__':
  51. core.GlobalInit(["caffe2", "--caffe2_cpu_numa_enabled=1"])
  52. main()