device_checker.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. ## @package device_checker
  2. # Module caffe2.python.device_checker
  3. import numpy as np
  4. import copy
  5. from caffe2.python import workspace
  6. from caffe2.python.core import InferOpBlobDevicesAsDict
  7. from future.utils import viewitems
  8. class DeviceChecker(object):
  9. """A device checker in Python to check consistency across multiple devices.
  10. This is not the most efficient way to check devices, as the Python interface
  11. will involve a lot of copies back and forth operations. Use at your own risk.
  12. """
  13. def __init__(self, threshold, device_options):
  14. self._threshold = threshold
  15. self._device_options = device_options
  16. def CheckSimple(self, op, inputs, outputs_to_check,
  17. input_device_options=None):
  18. """Checks the operator with different device implementations.
  19. Inputs:
  20. op: the operator to be checked.
  21. inputs: the input data in numpy arrays.
  22. outputs_to_check: the outputs to check between devices.
  23. input_device_options: a mapping from input name to a device to use
  24. (instead of self._device_options)
  25. Outputs:
  26. boolean: True if it passes, False if it does not pass.
  27. """
  28. op = copy.deepcopy(op)
  29. # Entering the checker workspace
  30. old_ws_name = workspace.CurrentWorkspace()
  31. results = []
  32. workspace.SwitchWorkspace("_device_check_", True)
  33. for i, device_option in enumerate(self._device_options):
  34. op.device_option.CopyFrom(device_option)
  35. _input_device_options = input_device_options or \
  36. InferOpBlobDevicesAsDict(op)[0]
  37. print(_input_device_options)
  38. for i, arr in enumerate(inputs):
  39. workspace.FeedBlob(
  40. op.input[i], np.array(arr),
  41. _input_device_options.get(op.input[i], device_option)
  42. )
  43. workspace.RunOperatorOnce(op)
  44. results.append(
  45. [workspace.FetchBlob(op.output[idx])
  46. for idx in outputs_to_check])
  47. # Everything is done, reset the workspace.
  48. workspace.ResetWorkspace()
  49. # After running on all devices, check correctness
  50. success = True
  51. for i in range(1, len(self._device_options)):
  52. for j in range(len(outputs_to_check)):
  53. x = results[i][j]
  54. y = results[0][j]
  55. if not np.allclose(x, y,
  56. atol=self._threshold, rtol=self._threshold):
  57. print('Failure in checking device option {}'
  58. ' and output {}. The outputs are:'
  59. .format(i, op.output[outputs_to_check[j]]))
  60. print(x.flatten())
  61. print(y.flatten())
  62. print(np.max(np.abs(x - y)))
  63. success = False
  64. # else:
  65. # print ('Passed device pair (0, %d), %s %s' %
  66. # (i, outputs_to_check[j], y.shape))
  67. workspace.SwitchWorkspace(old_ws_name)
  68. return success
  69. def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None):
  70. """Checks a network by inspecting all of its intermediate results, and
  71. see if things match.
  72. """
  73. if inputs is None:
  74. inputs = {}
  75. if ignore is None:
  76. ignore = set()
  77. old_ws_name = workspace.CurrentWorkspace()
  78. results = []
  79. if blobs_to_check is None:
  80. blobs_to_check = sum([list(op.output) for op in net.op], [])
  81. blobs_to_check = [b for b in blobs_to_check if b not in ignore]
  82. workspace.SwitchWorkspace("_device_check_", True)
  83. for device_option in self._device_options:
  84. for name, arr in viewitems(inputs):
  85. # print 'feeding', name
  86. workspace.FeedBlob(name, arr, device_option)
  87. for op in net.op:
  88. op.device_option.CopyFrom(device_option)
  89. workspace.RunNetOnce(net)
  90. results.append(
  91. [workspace.FetchBlob(name) for name in blobs_to_check]
  92. )
  93. # After running on all devices, check correctness
  94. success = True
  95. for i in range(1, len(results)):
  96. for j in range(len(blobs_to_check)):
  97. x = results[i][j]
  98. y = results[0][j]
  99. if not np.allclose(x, y,
  100. atol=self._threshold, rtol=self._threshold):
  101. print('Failure in checking device option {}'
  102. ' and output {}. The outputs are:'
  103. .format(i, blobs_to_check[j]))
  104. print(x.flatten())
  105. print(y.flatten())
  106. print(np.max(np.abs(x - y)))
  107. success = False
  108. # else:
  109. # print ('Passed device pair (%d, %d), %s %s: %s' %
  110. # (i, j, blobs_to_check[j], y.shape,
  111. # str(y.flatten())))
  112. workspace.SwitchWorkspace(old_ws_name)
  113. return success