ModuleRegister.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import inspect
  2. import logging
  3. logging.basicConfig()
  4. log = logging.getLogger("ModuleRegister")
  5. log.setLevel(logging.DEBUG)
  6. MODULE_MAPS = []
  7. def registerModuleMap(module_map):
  8. MODULE_MAPS.append(module_map)
  9. log.info("ModuleRegister get modules from ModuleMap content: {}".
  10. format(inspect.getsource(module_map)))
  11. def constructTrainerClass(myTrainerClass, opts):
  12. log.info("ModuleRegister, myTrainerClass name is {}".
  13. format(myTrainerClass.__name__))
  14. log.info("ModuleRegister, myTrainerClass type is {}".
  15. format(type(myTrainerClass)))
  16. log.info("ModuleRegister, myTrainerClass dir is {}".
  17. format(dir(myTrainerClass)))
  18. myInitializeModelModule = getModule(opts['model']['model_name_py'])
  19. log.info("ModuleRegister, myInitializeModelModule dir is {}".
  20. format(dir(myInitializeModelModule)))
  21. myTrainerClass.init_model = myInitializeModelModule.init_model
  22. myTrainerClass.run_training_net = myInitializeModelModule.run_training_net
  23. myTrainerClass.fun_per_iter_b4RunNet = \
  24. myInitializeModelModule.fun_per_iter_b4RunNet
  25. myTrainerClass.fun_per_epoch_b4RunNet = \
  26. myInitializeModelModule.fun_per_epoch_b4RunNet
  27. myInputModule = getModule(opts['input']['input_name_py'])
  28. log.info("ModuleRegister, myInputModule {} dir is {}".
  29. format(opts['input']['input_name_py'], myInputModule.__name__))
  30. # Override input methods of the myTrainerClass class
  31. myTrainerClass.get_input_dataset = myInputModule.get_input_dataset
  32. myTrainerClass.get_model_input_fun = myInputModule.get_model_input_fun
  33. myTrainerClass.gen_input_builder_fun = myInputModule.gen_input_builder_fun
  34. # myForwardPassModule = GetForwardPassModule(opts)
  35. myForwardPassModule = getModule(opts['model']['forward_pass_py'])
  36. myTrainerClass.gen_forward_pass_builder_fun = \
  37. myForwardPassModule.gen_forward_pass_builder_fun
  38. myParamUpdateModule = getModule(opts['model']['parameter_update_py'])
  39. myTrainerClass.gen_param_update_builder_fun =\
  40. myParamUpdateModule.gen_param_update_builder_fun \
  41. if myParamUpdateModule is not None else None
  42. myOptimizerModule = getModule(opts['model']['optimizer_py'])
  43. myTrainerClass.gen_optimizer_fun = \
  44. myOptimizerModule.gen_optimizer_fun \
  45. if myOptimizerModule is not None else None
  46. myRendezvousModule = getModule(opts['model']['rendezvous_py'])
  47. myTrainerClass.gen_rendezvous_ctx = \
  48. myRendezvousModule.gen_rendezvous_ctx \
  49. if myRendezvousModule is not None else None
  50. # override output module
  51. myOutputModule = getModule(opts['output']['gen_output_py'])
  52. log.info("ModuleRegister, myOutputModule is {}".
  53. format(myOutputModule.__name__))
  54. myTrainerClass.fun_conclude_operator = myOutputModule.fun_conclude_operator
  55. myTrainerClass.assembleAllOutputs = myOutputModule.assembleAllOutputs
  56. return myTrainerClass
  57. def overrideAdditionalMethods(myTrainerClass, opts):
  58. log.info("B4 additional override myTrainerClass source {}".
  59. format(inspect.getsource(myTrainerClass)))
  60. # override any additional modules
  61. myAdditionalOverride = getModule(opts['model']['additional_override_py'])
  62. if myAdditionalOverride is not None:
  63. for funcName, funcValue in inspect.getmembers(myAdditionalOverride,
  64. inspect.isfunction):
  65. setattr(myTrainerClass, funcName, funcValue)
  66. log.info("Aft additional override myTrainerClass's source {}".
  67. format(inspect.getsource(myTrainerClass)))
  68. return myTrainerClass
  69. def getModule(moduleName):
  70. log.info("get module {} from MODULE_MAPS content {}".format(moduleName, str(MODULE_MAPS)))
  71. myModule = None
  72. for ModuleMap in MODULE_MAPS:
  73. log.info("iterate through MODULE_MAPS content {}".
  74. format(str(ModuleMap)))
  75. for name, obj in inspect.getmembers(ModuleMap):
  76. log.info("iterate through MODULE_MAPS a name {}".format(str(name)))
  77. if name == moduleName:
  78. log.info("AnyExp get module {} with source:{}".
  79. format(moduleName, inspect.getsource(obj)))
  80. myModule = obj
  81. return myModule
  82. return None
  83. def getClassFromModule(moduleName, className):
  84. myClass = None
  85. for ModuleMap in MODULE_MAPS:
  86. for name, obj in inspect.getmembers(ModuleMap):
  87. if name == moduleName:
  88. log.info("ModuleRegistry from module {} get class {} of source:{}".
  89. format(moduleName, className, inspect.getsource(obj)))
  90. myClass = getattr(obj, className)
  91. return myClass
  92. return None