experiment_util.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. ## @package experiment_util
  2. # Module caffe2.python.experiment_util
  3. import datetime
  4. import time
  5. import logging
  6. import socket
  7. import abc
  8. from collections import OrderedDict
  9. from future.utils import viewkeys, viewvalues
  10. '''
  11. Utilities for logging experiment run stats, such as accuracy
  12. and loss over time for different runs. Runtime arguments are stored
  13. in the log.
  14. Optionally, ModelTrainerLog calls out to a logger to log to
  15. an external log destination.
  16. '''
  17. class ExternalLogger(object):
  18. __metaclass__ = abc.ABCMeta
  19. @abc.abstractmethod
  20. def set_runtime_args(self, runtime_args):
  21. """
  22. Set runtime arguments for the logger.
  23. runtime_args: dict of runtime arguments.
  24. """
  25. raise NotImplementedError(
  26. 'Must define set_runtime_args function to use this base class'
  27. )
  28. @abc.abstractmethod
  29. def log(self, log_dict):
  30. """
  31. log a dict of key/values to an external destination
  32. log_dict: input dict
  33. """
  34. raise NotImplementedError(
  35. 'Must define log function to use this base class'
  36. )
  37. class ModelTrainerLog():
  38. def __init__(self, expname, runtime_args, external_loggers=None):
  39. now = datetime.datetime.fromtimestamp(time.time())
  40. self.experiment_id = \
  41. "{}_{}".format(expname, now.strftime('%Y%m%d_%H%M%S'))
  42. self.filename = "{}.log".format(self.experiment_id)
  43. self.logstr("# %s" % str(runtime_args))
  44. self.headers = None
  45. self.start_time = time.time()
  46. self.last_time = self.start_time
  47. self.last_input_count = 0
  48. self.external_loggers = None
  49. if external_loggers is not None:
  50. self.external_loggers = external_loggers
  51. if not isinstance(runtime_args, dict):
  52. runtime_args = dict(vars(runtime_args))
  53. runtime_args['experiment_id'] = self.experiment_id
  54. runtime_args['hostname'] = socket.gethostname()
  55. for logger in self.external_loggers:
  56. logger.set_runtime_args(runtime_args)
  57. else:
  58. self.external_loggers = []
  59. def logstr(self, str):
  60. with open(self.filename, "a") as f:
  61. f.write(str + "\n")
  62. f.close()
  63. logging.getLogger("experiment_logger").info(str)
  64. def log(self, input_count, batch_count, additional_values):
  65. logdict = OrderedDict()
  66. delta_t = time.time() - self.last_time
  67. delta_count = input_count - self.last_input_count
  68. self.last_time = time.time()
  69. self.last_input_count = input_count
  70. logdict['time_spent'] = delta_t
  71. logdict['cumulative_time_spent'] = time.time() - self.start_time
  72. logdict['input_count'] = delta_count
  73. logdict['cumulative_input_count'] = input_count
  74. logdict['cumulative_batch_count'] = batch_count
  75. if delta_t > 0:
  76. logdict['inputs_per_sec'] = delta_count / delta_t
  77. else:
  78. logdict['inputs_per_sec'] = 0.0
  79. for k in sorted(viewkeys(additional_values)):
  80. logdict[k] = additional_values[k]
  81. # Write the headers if they are not written yet
  82. if self.headers is None:
  83. self.headers = list(viewkeys(logdict))
  84. self.logstr(",".join(self.headers))
  85. self.logstr(",".join(str(v) for v in viewvalues(logdict)))
  86. for logger in self.external_loggers:
  87. try:
  88. logger.log(logdict)
  89. except Exception as e:
  90. logging.warning(
  91. "Failed to call ExternalLogger: {}".format(e), e)