# model_device_test.py (4.7 KB)
  1. import numpy as np
  2. import unittest
  3. from caffe2.proto import caffe2_pb2
  4. from caffe2.python import (
  5. workspace,
  6. device_checker,
  7. test_util,
  8. model_helper,
  9. brew,
  10. )
  11. class TestMiniAlexNet(test_util.TestCase):
  12. def _MiniAlexNetNoDropout(self, order):
  13. # First, AlexNet using the cnn wrapper.
  14. model = model_helper.ModelHelper(name="alexnet")
  15. conv1 = brew.conv(
  16. model,
  17. "data",
  18. "conv1",
  19. 3,
  20. 16,
  21. 11,
  22. ("XavierFill", {}),
  23. ("ConstantFill", {}),
  24. stride=4,
  25. pad=0
  26. )
  27. relu1 = brew.relu(model, conv1, "relu1")
  28. norm1 = brew.lrn(model, relu1, "norm1", size=5, alpha=0.0001, beta=0.75)
  29. pool1 = brew.max_pool(model, norm1, "pool1", kernel=3, stride=2)
  30. conv2 = brew.group_conv(
  31. model,
  32. pool1,
  33. "conv2",
  34. 16,
  35. 32,
  36. 5,
  37. ("XavierFill", {}),
  38. ("ConstantFill", {"value": 0.1}),
  39. group=2,
  40. stride=1,
  41. pad=2
  42. )
  43. relu2 = brew.relu(model, conv2, "relu2")
  44. norm2 = brew.lrn(model, relu2, "norm2", size=5, alpha=0.0001, beta=0.75)
  45. pool2 = brew.max_pool(model, norm2, "pool2", kernel=3, stride=2)
  46. conv3 = brew.conv(
  47. model,
  48. pool2,
  49. "conv3",
  50. 32,
  51. 64,
  52. 3,
  53. ("XavierFill", {'std': 0.01}),
  54. ("ConstantFill", {}),
  55. pad=1
  56. )
  57. relu3 = brew.relu(model, conv3, "relu3")
  58. conv4 = brew.group_conv(
  59. model,
  60. relu3,
  61. "conv4",
  62. 64,
  63. 64,
  64. 3,
  65. ("XavierFill", {}),
  66. ("ConstantFill", {"value": 0.1}),
  67. group=2,
  68. pad=1
  69. )
  70. relu4 = brew.relu(model, conv4, "relu4")
  71. conv5 = brew.group_conv(
  72. model,
  73. relu4,
  74. "conv5",
  75. 64,
  76. 32,
  77. 3,
  78. ("XavierFill", {}),
  79. ("ConstantFill", {"value": 0.1}),
  80. group=2,
  81. pad=1
  82. )
  83. relu5 = brew.relu(model, conv5, "relu5")
  84. pool5 = brew.max_pool(model, relu5, "pool5", kernel=3, stride=2)
  85. fc6 = brew.fc(
  86. model, pool5, "fc6", 1152, 1024, ("XavierFill", {}),
  87. ("ConstantFill", {"value": 0.1})
  88. )
  89. relu6 = brew.relu(model, fc6, "relu6")
  90. fc7 = brew.fc(
  91. model, relu6, "fc7", 1024, 1024, ("XavierFill", {}),
  92. ("ConstantFill", {"value": 0.1})
  93. )
  94. relu7 = brew.relu(model, fc7, "relu7")
  95. fc8 = brew.fc(
  96. model, relu7, "fc8", 1024, 5, ("XavierFill", {}),
  97. ("ConstantFill", {"value": 0.0})
  98. )
  99. pred = brew.softmax(model, fc8, "pred")
  100. xent = model.LabelCrossEntropy([pred, "label"], "xent")
  101. loss = model.AveragedLoss([xent], ["loss"])
  102. model.AddGradientOperators([loss])
  103. return model
  104. def _testMiniAlexNet(self, order):
  105. # First, we get all the random initialization of parameters.
  106. model = self._MiniAlexNetNoDropout(order)
  107. workspace.ResetWorkspace()
  108. workspace.RunNetOnce(model.param_init_net)
  109. inputs = dict(
  110. [(str(name), workspace.FetchBlob(str(name))) for name in
  111. model.params]
  112. )
  113. if order == "NCHW":
  114. inputs["data"] = np.random.rand(4, 3, 227, 227).astype(np.float32)
  115. else:
  116. inputs["data"] = np.random.rand(4, 227, 227, 3).astype(np.float32)
  117. inputs["label"] = np.array([1, 2, 3, 4]).astype(np.int32)
  118. cpu_device = caffe2_pb2.DeviceOption()
  119. cpu_device.device_type = caffe2_pb2.CPU
  120. gpu_device = caffe2_pb2.DeviceOption()
  121. gpu_device.device_type = workspace.GpuDeviceType
  122. checker = device_checker.DeviceChecker(0.05, [cpu_device, gpu_device])
  123. ret = checker.CheckNet(
  124. model.net.Proto(),
  125. inputs,
  126. # The indices sometimes may be sensitive to small numerical
  127. # differences in the input, so we ignore checking them.
  128. ignore=['_pool1_idx', '_pool2_idx', '_pool5_idx']
  129. )
  130. self.assertEqual(ret, True)
  131. @unittest.skipIf(not workspace.has_gpu_support,
  132. "No GPU support. Skipping test.")
  133. def testMiniAlexNetNCHW(self):
  134. self._testMiniAlexNet("NCHW")
  135. # No Group convolution support for NHWC right now
  136. #@unittest.skipIf(not workspace.has_gpu_support,
  137. # "No GPU support. Skipping test.")
  138. #def testMiniAlexNetNHWC(self):
  139. # self._testMiniAlexNet("NHWC")
  140. if __name__ == '__main__':
  141. unittest.main()