| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import inspect
- import logging
# Registry of "module map" objects added via registerModuleMap().
# getModule() and getClassFromModule() scan it in list (registration) order.
MODULE_MAPS = list()

# Module-wide logger.  basicConfig() gives the root logger a default stderr
# handler (it is a no-op if the root logger is already configured), so the
# registry's diagnostics are actually printed.
logging.basicConfig()
log = logging.getLogger("ModuleRegister")
log.setLevel(level=logging.DEBUG)
def registerModuleMap(module_map):
    """Register ``module_map`` so later name lookups can find its members.

    ``module_map`` is any object whose attributes (as enumerated by
    ``inspect.getmembers``) are the pluggable modules, addressed by attribute
    name.  Maps are appended to ``MODULE_MAPS`` and searched in that order.

    The map's source code is logged at INFO for diagnostics.  If the source
    cannot be retrieved, a placeholder is logged instead of letting the
    diagnostic abort the registration.
    """
    MODULE_MAPS.append(module_map)
    # Only pay for reading the source when INFO records will be emitted.
    if log.isEnabledFor(logging.INFO):
        try:
            source = inspect.getsource(module_map)
        except (OSError, TypeError):
            # getsource raises TypeError for built-in objects and OSError
            # when the source file is unavailable (e.g. .pyc-only installs).
            source = "<source unavailable>"
        log.info("ModuleRegister get modules from ModuleMap content: %s",
                 source)
def _requireModule(moduleName):
    """Return getModule(moduleName), raising ValueError if no map provides it."""
    module = getModule(moduleName)
    if module is None:
        raise ValueError(
            "ModuleRegister: required module {} not found in any registered "
            "ModuleMap".format(moduleName))
    return module


def _copyFunctions(targetClass, module, funcNames):
    """Set ``targetClass.<name> = module.<name>`` for each name, in order."""
    for funcName in funcNames:
        setattr(targetClass, funcName, getattr(module, funcName))


def constructTrainerClass(myTrainerClass, opts):
    """Attach the pluggable model/input/output functions to a trainer class.

    Modules are resolved by name with getModule(); the names are read from
    ``opts``:

    * ``opts['model']['model_name_py']`` (required): init_model,
      run_training_net, fun_per_iter_b4RunNet, fun_per_epoch_b4RunNet
    * ``opts['input']['input_name_py']`` (required): get_input_dataset,
      get_model_input_fun, gen_input_builder_fun
    * ``opts['model']['forward_pass_py']`` (required):
      gen_forward_pass_builder_fun
    * ``opts['model']['parameter_update_py']`` (optional):
      gen_param_update_builder_fun
    * ``opts['model']['optimizer_py']`` (optional): gen_optimizer_fun
    * ``opts['model']['rendezvous_py']`` (optional): gen_rendezvous_ctx
    * ``opts['output']['gen_output_py']`` (required): fun_conclude_operator,
      assembleAllOutputs

    An optional hook is set to None when its module is not registered.

    Args:
        myTrainerClass: class to patch; it is modified in place.
        opts: nested options mapping containing the module names above.

    Returns:
        The same ``myTrainerClass`` object.

    Raises:
        ValueError: a required module is not found in any registered map.
        AttributeError: a found module lacks one of the expected functions.
        Lookup errors from ``opts`` (e.g. KeyError for a dict) if a key above
        is missing.
    """
    log.info("ModuleRegister, myTrainerClass name is %s",
             myTrainerClass.__name__)
    log.info("ModuleRegister, myTrainerClass type is %s", type(myTrainerClass))
    log.info("ModuleRegister, myTrainerClass dir is %s", dir(myTrainerClass))

    modelOpts = opts['model']

    # Model initialisation plus per-iteration / per-epoch hooks.
    myInitializeModelModule = _requireModule(modelOpts['model_name_py'])
    log.info("ModuleRegister, myInitializeModelModule dir is %s",
             dir(myInitializeModelModule))
    _copyFunctions(myTrainerClass, myInitializeModelModule, (
        'init_model',
        'run_training_net',
        'fun_per_iter_b4RunNet',
        'fun_per_epoch_b4RunNet',
    ))

    # Input methods.
    inputModuleName = opts['input']['input_name_py']
    myInputModule = _requireModule(inputModuleName)
    log.info("ModuleRegister, myInputModule %s name is %s",
             inputModuleName, myInputModule.__name__)
    _copyFunctions(myTrainerClass, myInputModule, (
        'get_input_dataset',
        'get_model_input_fun',
        'gen_input_builder_fun',
    ))

    # Forward pass builder.
    myForwardPassModule = _requireModule(modelOpts['forward_pass_py'])
    _copyFunctions(myTrainerClass, myForwardPassModule,
                   ('gen_forward_pass_builder_fun',))

    # Optional hooks: left as None when the named module is not registered.
    for optsKey, funcName in (
            ('parameter_update_py', 'gen_param_update_builder_fun'),
            ('optimizer_py', 'gen_optimizer_fun'),
            ('rendezvous_py', 'gen_rendezvous_ctx'),
    ):
        optionalModule = getModule(modelOpts[optsKey])
        setattr(myTrainerClass, funcName,
                None if optionalModule is None
                else getattr(optionalModule, funcName))

    # Output methods.
    myOutputModule = _requireModule(opts['output']['gen_output_py'])
    log.info("ModuleRegister, myOutputModule is %s", myOutputModule.__name__)
    _copyFunctions(myTrainerClass, myOutputModule, (
        'fun_conclude_operator',
        'assembleAllOutputs',
    ))
    return myTrainerClass
def overrideAdditionalMethods(myTrainerClass, opts):
    """Copy every function of the "additional override" module onto a class.

    The module is resolved with
    ``getModule(opts['model']['additional_override_py'])``.  If no registered
    map provides it, the class is left untouched.  Otherwise every member
    for which ``inspect.isfunction`` is true is set on ``myTrainerClass``
    under its own name, replacing any existing attribute.

    Args:
        myTrainerClass: class to patch; it is modified in place.
        opts: nested options mapping holding the override module name.

    Returns:
        The same ``myTrainerClass`` object.
    """
    if log.isEnabledFor(logging.INFO):
        try:
            classSource = inspect.getsource(myTrainerClass)
        except (OSError, TypeError):
            # getsource fails for built-in classes or when the defining
            # source file cannot be located; don't let logging abort this.
            classSource = "<source unavailable>"
        log.info("B4 additional override myTrainerClass source %s",
                 classSource)

    overridden = []
    myAdditionalOverride = getModule(opts['model']['additional_override_py'])
    if myAdditionalOverride is not None:
        # NOTE(review): getmembers also returns functions the override module
        # merely imports from elsewhere; confirm that copying those is intended.
        for funcName, funcValue in inspect.getmembers(myAdditionalOverride,
                                                      inspect.isfunction):
            setattr(myTrainerClass, funcName, funcValue)
            overridden.append(funcName)

    # inspect.getsource() re-reads the class definition from its file, so it
    # can never show attributes set at runtime; log what was replaced instead.
    log.info("Aft additional override myTrainerClass's overridden functions %s",
             overridden)
    return myTrainerClass
def getModule(moduleName):
    """Return the member named ``moduleName`` from the registered maps.

    ``MODULE_MAPS`` is scanned in registration order.  Each map's members are
    enumerated with ``inspect.getmembers`` and the first member whose name
    equals ``moduleName`` is returned.

    Returns:
        The matching member, or None if no registered map has one.
    """
    log.info("get module %s from MODULE_MAPS content %s",
             moduleName, MODULE_MAPS)
    for moduleMap in MODULE_MAPS:
        log.debug("iterate through MODULE_MAPS content %s", moduleMap)
        for name, obj in inspect.getmembers(moduleMap):
            # One record per member is very noisy, so it is DEBUG rather
            # than INFO; raise the logger level to silence it.
            log.debug("iterate through MODULE_MAPS a name %s", name)
            if name != moduleName:
                continue
            if log.isEnabledFor(logging.INFO):
                try:
                    source = inspect.getsource(obj)
                except (OSError, TypeError):
                    # A diagnostic must not turn a successful lookup into
                    # a failure (built-ins / missing source files raise).
                    source = "<source unavailable>"
                log.info("AnyExp get module %s with source:%s",
                         moduleName, source)
            return obj
    return None
def getClassFromModule(moduleName, className):
    """Return attribute ``className`` of the registered module ``moduleName``.

    ``MODULE_MAPS`` is scanned in registration order for a member named
    ``moduleName``.  The first match is used and ``className`` is read from it.

    Returns:
        The requested attribute, or None if no registered map has a member
        named ``moduleName``.

    Raises:
        AttributeError: the module is found but has no ``className``.
    """
    for moduleMap in MODULE_MAPS:
        for name, obj in inspect.getmembers(moduleMap):
            if name != moduleName:
                continue
            if log.isEnabledFor(logging.INFO):
                try:
                    source = inspect.getsource(obj)
                except (OSError, TypeError):
                    # Source may be unavailable; keep the lookup working.
                    source = "<source unavailable>"
                log.info("ModuleRegistry from module %s get class %s of "
                         "source:%s", moduleName, className, source)
            return getattr(obj, className)
    return None
|