optimizer.py 77 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259
  1. # @package optimizer
  2. # Module caffe2.python.optimizer
  3. import copy
  4. import logging
  5. from collections import defaultdict, namedtuple
  6. import numpy as np
  7. from caffe2.proto import caffe2_pb2
  8. from caffe2.python import core, scope, utils, workspace
  9. from caffe2.python.modeling import parameter_info
  10. from past.builtins import basestring
  11. _LEARNING_RATE_INJECTION = "lr_injection"
  12. AuxOptimizerParams = namedtuple("AuxOptimizerParams", ["local", "shared"])
  13. _optimizer_instance_count = defaultdict(int)
  14. FP16_ENGINES = ["SIMD_Q_FP16", "SIMD_Q_STOC_FP16", "SIMD_Q_STOC_MKL_FP16"]
  15. logger = logging.getLogger(__name__)
  16. def reset_optimizer_instance_count():
  17. """
  18. This function clears the _optimizer_instance_count. And keeps it
  19. empty. This functionality is needed in some situations where
  20. optimizer instance count might not reset even though the workplace is reset.
  21. """
  22. _optimizer_instance_count.clear()
  23. class Optimizer(object):
  24. def __init__(self):
  25. self._aux_params = AuxOptimizerParams(local=[], shared=[])
  26. self._instance_num = _optimizer_instance_count[self.__class__.__name__]
  27. _optimizer_instance_count[self.__class__.__name__] += 1
  28. self._lr_multiplier = None
  29. self._local_lr_multiplier = None
  30. self._local_lr_multiplier_on_gpu = False
  31. """
  32. Adds optimization operators to the net for given parameter and its gradient
  33. Parameter is specified by either 'param' being a ParameterInfo object.
  34. In this case param.grad has to be set
  35. Or by 'param' being a BlobReference and 'grad' being a BlobReference for its
  36. gradient.
  37. """
  38. def __call__(self, net, param_init_net, param, grad=None):
  39. if grad is None:
  40. assert isinstance(
  41. param, parameter_info.ParameterInfo
  42. ), "Expected parameter to be of type ParameterInfo, got {}".format(param)
  43. assert param.grad is not None
  44. else:
  45. if isinstance(param, basestring):
  46. param = core.BlobReference(param)
  47. param = parameter_info.ParameterInfo(param_id=None, param=param, grad=grad)
  48. self._run(net, param_init_net, param)
  49. def _run(self, net, param_init_net, param_info):
  50. raise Exception("Not Implemented")
  51. def get_cpu_blob_name(self, base_str, node_name=""):
  52. classname = self.__class__.__name__
  53. return "%s_%d_%s%s_cpu" % (classname, self._instance_num, base_str, node_name)
  54. def get_gpu_blob_name(self, base_str, gpu_id, node_name):
  55. classname = self.__class__.__name__
  56. return "%s_%d_%s%s_gpu%d" % (
  57. classname,
  58. self._instance_num,
  59. base_str,
  60. node_name,
  61. gpu_id,
  62. )
  63. @property
  64. def attributes(self):
  65. # return a dict that contains attributes related to init args only
  66. attr = copy.deepcopy(self.__dict__)
  67. del attr["_instance_num"]
  68. return attr
  69. def make_unique_blob_name(self, base_str):
  70. """
  71. Returns a blob name that will be unique to the current device
  72. and optimizer instance.
  73. """
  74. current_scope = scope.CurrentDeviceScope()
  75. if current_scope is None:
  76. return self.get_cpu_blob_name(base_str)
  77. if core.IsGPUDeviceType(current_scope.device_type):
  78. return self.get_gpu_blob_name(
  79. base_str, current_scope.device_id, current_scope.node_name
  80. )
  81. else:
  82. return self.get_cpu_blob_name(base_str, current_scope.node_name)
  83. def build_lr(
  84. self,
  85. net,
  86. param_init_net,
  87. base_learning_rate,
  88. learning_rate_blob=None,
  89. policy="fixed",
  90. iter_val=0,
  91. **kwargs
  92. ):
  93. if learning_rate_blob is None:
  94. learning_rate_blob = self.make_unique_blob_name("lr")
  95. iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)
  96. if not net.BlobIsDefined(learning_rate_blob):
  97. # There is one interesting thing here: since we are minimizing, we are
  98. # doing "descent" so the learning rate is set to be negative.
  99. lr = net.LearningRate(
  100. [iteration],
  101. learning_rate_blob,
  102. base_lr=-base_learning_rate,
  103. policy=policy,
  104. **kwargs
  105. )
  106. else:
  107. lr = net.GetBlobRef(learning_rate_blob)
  108. if self._lr_multiplier is not None:
  109. lr_multiplier = net.CopyFromCPUInput(
  110. self._lr_multiplier, self.make_unique_blob_name("lr_multiplier")
  111. )
  112. lr = net.Mul(
  113. [lr, lr_multiplier],
  114. self.make_unique_blob_name("scaled_lr"),
  115. broadcast=1,
  116. )
  117. if self._local_lr_multiplier is not None:
  118. current_scope = scope.CurrentDeviceScope()
  119. if (
  120. current_scope is not None
  121. and core.IsGPUDeviceType(current_scope.device_type)
  122. and not self._local_lr_multiplier_on_gpu
  123. ):
  124. local_lr_multiplier = net.CopyFromCPUInput(
  125. self._local_lr_multiplier,
  126. self.make_unique_blob_name("local_lr_multiplier"),
  127. )
  128. else:
  129. local_lr_multiplier = self._local_lr_multiplier
  130. lr = net.Mul(
  131. [lr, local_lr_multiplier],
  132. self.make_unique_blob_name("local_scaled_lr"),
  133. broadcast=1,
  134. )
  135. return lr, iteration
  136. def add_lr_multiplier(self, lr_multiplier):
  137. """
  138. Set the global learning rate multiplier. If a multiplier already
  139. existed, this will overwrite the existing multiplier. The multiplier is
  140. used for all future calls to _run(), unless it is overwritten.
  141. """
  142. self._lr_multiplier = lr_multiplier
  143. def _add_local_lr_multiplier(self, local_lr_multiplier, is_gpu_blob=False):
  144. """
  145. Set the local learning rate multiplier. This local multiplier is
  146. multiplied with the global learning rate multiplier if it exists. As
  147. with the global learning rate multiplier, this multiplier will be
  148. used for all future calls to _run(), so please call
  149. _clear_local_lr_multiplier() at the beginning of the optimizer's _run()
  150. before optionally calling this function.
  151. """
  152. self._local_lr_multiplier = local_lr_multiplier
  153. self._local_lr_multiplier_on_gpu = is_gpu_blob
  154. def _clear_local_lr_multiplier(self):
  155. self._local_lr_multiplier = None
  156. self._local_lr_multiplier_on_gpu = False
  157. @staticmethod
  158. def dedup(net, sparse_dedup_aggregator, grad):
  159. assert isinstance(
  160. grad, core.GradientSlice
  161. ), "Dedup only works for sparse gradient, got {}".format(grad)
  162. if sparse_dedup_aggregator:
  163. return net.DeduplicateGradientSlices(
  164. grad, aggregator=sparse_dedup_aggregator
  165. )
  166. else:
  167. return grad
  168. def get_auxiliary_parameters(self):
  169. """Returns a list of auxiliary parameters.
  170. Returns:
  171. aux_params: A namedtuple, AuxParams.
  172. aux_params.local stores a list of blobs. Each blob is a local
  173. auxiliary parameter. A local auxiliary parameter is a parameter in
  174. parallel to a learning rate parameter. Take adagrad as an example,
  175. the local auxiliary parameter is the squared sum parameter, because
  176. every learning rate has a squared sum associated with it.
  177. aux_params.shared also stores a list of blobs. Each blob is a shared
  178. auxiliary parameter. A shared auxiliary parameter is a parameter
  179. that is shared across all the learning rate parameters. Take adam as
  180. an example, the iteration parameter is a shared parameter, because
  181. all the learning rates share the same iteration parameter.
  182. """
  183. return self._aux_params
  184. # TODO(xlwang): In transfer learning, parameter initialized from pretrained
  185. # model might require a different learning rate than otherwise initialized.
  186. # To this end, here we implement a python solution where
  187. # `base_learning_rate` is scaled by `scale`, by calling
  188. # `scale_learning_rate`; Alternatively, we can achieve same effect by
  189. # rewriting the LearningRate operator in C++
  190. # Note that it is the responsibility of specific optimizer to decide what
  191. # logic should be used for `scale_learning_rate`
  192. def scale_learning_rate(self, *args, **kwargs):
  193. raise NotImplementedError(
  194. "Optimizer Need to Implement `scale_learning_rate` method."
  195. )
  196. def create_lars_inputs(self, param_init_net, weight_decay, trust, lr_max):
  197. wd = param_init_net.ConstantFill(
  198. [], "weight_decay", shape=[1], value=weight_decay
  199. )
  200. trust = param_init_net.ConstantFill([], "trust", shape=[1], value=trust)
  201. lr_max = param_init_net.ConstantFill([], "lr_max", shape=[1], value=lr_max)
  202. return wd, trust, lr_max
  203. class SgdOptimizer(Optimizer):
  204. def __init__(
  205. self,
  206. base_learning_rate=0.01,
  207. policy="fixed",
  208. momentum=0.0,
  209. nesterov=True,
  210. sparse_dedup_aggregator=None,
  211. lars=None,
  212. **kwargs
  213. ):
  214. super(SgdOptimizer, self).__init__()
  215. self.base_learning_rate = base_learning_rate
  216. self.policy = policy
  217. self.momentum = momentum
  218. self.nesterov = nesterov
  219. self.sparse_dedup_aggregator = sparse_dedup_aggregator
  220. self.lars = lars
  221. self.init_kwargs = kwargs
  222. def _run(self, net, param_init_net, param_info):
  223. param = param_info.blob
  224. grad = param_info.grad
  225. if self.base_learning_rate == 0:
  226. return
  227. assert (
  228. self.base_learning_rate > 0
  229. ), "Expect positive base learning rate, got {}".format(self.base_learning_rate)
  230. self._clear_local_lr_multiplier()
  231. # TODO(zqq): support LARS for sparse parameters
  232. if self.lars is not None and not isinstance(grad, core.GradientSlice):
  233. assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
  234. self.lars
  235. )
  236. wd, trust, lr_max = self.create_lars_inputs(
  237. param_init_net, 0.0, 1.0, np.finfo(np.float32).max
  238. )
  239. lr_lars_multiplier = net.Lars(
  240. [param, grad, wd, trust, lr_max],
  241. self.make_unique_blob_name(str(param) + "_lars"),
  242. offset=self.lars,
  243. lr_min=0.0,
  244. )
  245. current_scope = scope.CurrentDeviceScope()
  246. self._add_local_lr_multiplier(
  247. lr_lars_multiplier,
  248. is_gpu_blob=(
  249. current_scope is not None
  250. and core.IsGPUDeviceType(current_scope.device_type)
  251. ),
  252. )
  253. # We need negative sign for LR when used directly with WeightedSum
  254. # below.
  255. lr_sign = -1 if self.momentum else 1
  256. lr, _ = self.build_lr(
  257. net,
  258. param_init_net,
  259. base_learning_rate=self.base_learning_rate * lr_sign,
  260. policy=self.policy,
  261. **(self.init_kwargs)
  262. )
  263. dev = scope.CurrentDeviceScope()
  264. if dev is None:
  265. dev = core.DeviceOption(caffe2_pb2.CPU)
  266. # Each GPU/CPU must have its own ONE blob, thus modify the name
  267. # to include device information.
  268. ONE = param_init_net.ConstantFill(
  269. [],
  270. "ONE_{}_{}{}".format(dev.device_type, dev.device_id, dev.node_name),
  271. shape=[1],
  272. value=1.0,
  273. )
  274. self._aux_params.shared.append(ONE)
  275. if self.momentum > 0:
  276. momentum_data = param_init_net.ConstantFill(
  277. param, str(param) + "_momentum", value=0.0
  278. )
  279. self._aux_params.local.append(momentum_data)
  280. if isinstance(grad, core.GradientSlice):
  281. grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
  282. if self.momentum > 0.0:
  283. net.SparseMomentumSGDUpdate(
  284. [grad.values, momentum_data, lr, param, grad.indices],
  285. [grad.values, momentum_data, param],
  286. momentum=self.momentum,
  287. nesterov=self.nesterov,
  288. )
  289. else:
  290. net.ScatterWeightedSum(
  291. [param, ONE, grad.indices, grad.values, lr], param
  292. )
  293. else:
  294. if self.momentum > 0.0:
  295. net.MomentumSGDUpdate(
  296. [grad, momentum_data, lr, param],
  297. [grad, momentum_data, param],
  298. momentum=self.momentum,
  299. nesterov=self.nesterov,
  300. )
  301. else:
  302. coeff = lr
  303. net.WeightedSum([param, ONE, grad, coeff], param)
  304. def scale_learning_rate(self, scale):
  305. self.base_learning_rate *= scale
  306. return
  307. class MultiPrecisionSgdOptimizer(SgdOptimizer):
  308. def __init__(
  309. self,
  310. base_learning_rate=0.1,
  311. momentum=0.0,
  312. policy="fixed",
  313. nesterov=True,
  314. sparse_dedup_aggregator=None,
  315. **kwargs
  316. ):
  317. super(MultiPrecisionSgdOptimizer, self).__init__(
  318. base_learning_rate=base_learning_rate,
  319. policy=policy,
  320. momentum=momentum,
  321. nesterov=nesterov,
  322. sparse_dedup_aggregator=sparse_dedup_aggregator,
  323. **kwargs
  324. )
  325. def _run(self, net, param_init_net, param_info):
  326. param = param_info.blob
  327. param_fp32 = (
  328. param_info.blob_copy[core.DataType.FLOAT]
  329. if param_info.blob_copy is not None
  330. else None
  331. )
  332. # If we have a straight fp32 parameter, run the base class
  333. if param_fp32 is None:
  334. return SgdOptimizer._run(self, net, param_init_net, param_info)
  335. grad = param_info.grad
  336. if self.base_learning_rate == 0:
  337. return
  338. assert (
  339. self.base_learning_rate > 0
  340. ), "Expect positive base learning rate, got {}".format(self.base_learning_rate)
  341. lr, _ = self.build_lr(
  342. net,
  343. param_init_net,
  344. base_learning_rate=-self.base_learning_rate,
  345. policy=self.policy,
  346. **(self.init_kwargs)
  347. )
  348. momentum_data = param_init_net.ConstantFill(
  349. param_fp32, str(param) + "_momentum", value=0.0
  350. )
  351. self._aux_params.local.append(momentum_data)
  352. assert not isinstance(
  353. grad, core.GradientSlice
  354. ), "MultiPrecisionSgd does not support sparse gradients"
  355. # Copy gradient to fp32
  356. grad_fp32 = net.HalfToFloat(grad, grad + "_fp32")
  357. # update (fused) in fp32
  358. net.MomentumSGDUpdate(
  359. [grad_fp32, momentum_data, lr, param_fp32],
  360. [grad_fp32, momentum_data, param_fp32],
  361. momentum=self.momentum,
  362. nesterov=self.nesterov,
  363. )
  364. # Copy updated param back to fp16
  365. net.FloatToHalf(param_fp32, param)
  366. class FP16SgdOptimizer(SgdOptimizer):
  367. def __init__(
  368. self,
  369. base_learning_rate=0.1,
  370. momentum=0.0,
  371. policy="fixed",
  372. nesterov=True,
  373. weight_decay=0.0001,
  374. sparse_dedup_aggregator=None,
  375. **kwargs
  376. ):
  377. super(FP16SgdOptimizer, self).__init__(
  378. base_learning_rate=base_learning_rate,
  379. policy=policy,
  380. momentum=momentum,
  381. nesterov=nesterov,
  382. sparse_dedup_aggregator=sparse_dedup_aggregator,
  383. **kwargs
  384. )
  385. self.weight_decay = weight_decay
  386. def _run(self, net, param_init_net, param_info, fp32_update=False):
  387. fp32_update_flag = 0
  388. param_name = str(param_info.blob)
  389. # should only be triggered in FP16 training by SpatialBN, which
  390. # requires FP32 params in CuDNN.
  391. if param_name.find("spatbn") != -1:
  392. fp32_update = True
  393. if fp32_update:
  394. # doing a 32bit update
  395. # Have to assume param_info.blob is FP32 as there is no way
  396. # (that i currently know of) to query a blob's type in python
  397. fp32_update_flag = 1
  398. param = param_info.blob
  399. param_fp32 = param_info.blob
  400. else:
  401. if param_info.blob_copy is None:
  402. # doing a 32bit update
  403. # Have to assume param_info.blob is FP32 as there is no way
  404. # (that i currently know of) to query a blob's type in python
  405. fp32_update_flag = 1
  406. param = param_info.blob
  407. param_fp32 = param_info.blob
  408. else:
  409. if core.DataType.FLOAT in param_info.blob_copy:
  410. param = param_info.blob
  411. param_fp32 = param_info.blob_copy[core.DataType.FLOAT]
  412. elif core.DataType.FLOAT16 in param_info.blob_copy:
  413. param = param_info.blob_copy[core.DataType.FLOAT16]
  414. param_fp32 = param_info.blob
  415. else:
  416. AssertionError(
  417. "Unrecognized parameter format to be updated "
  418. "by FP16 Optimizer. Parameter: {}".format(param_info.name)
  419. )
  420. grad = param_info.grad
  421. if self.base_learning_rate == 0:
  422. return
  423. assert (
  424. self.base_learning_rate > 0
  425. ), "Expect positive base learning rate, got {}".format(self.base_learning_rate)
  426. lr, _ = self.build_lr(
  427. net,
  428. param_init_net,
  429. base_learning_rate=-self.base_learning_rate,
  430. policy=self.policy,
  431. **(self.init_kwargs)
  432. )
  433. momentum_data_fp32 = param_init_net.ConstantFill(
  434. param_fp32, str(param) + "_momentum_fp32", value=0.0
  435. )
  436. momentum_data = param_init_net.FloatToHalf(
  437. momentum_data_fp32, str(param) + "_momentum"
  438. )
  439. self._aux_params.local.append(momentum_data)
  440. assert not isinstance(
  441. grad, core.GradientSlice
  442. ), "FP16Sgd does not support sparse gradients"
  443. if fp32_update_flag == 0:
  444. net.FP16MomentumSGDUpdate(
  445. [grad, momentum_data, lr, param],
  446. [grad, momentum_data, param],
  447. momentum=self.momentum,
  448. nesterov=self.nesterov,
  449. weight_decay=self.weight_decay,
  450. )
  451. else:
  452. # flag set to 1, therefore doing FP32 update
  453. net.FP32MomentumSGDUpdate(
  454. [grad, momentum_data_fp32, lr, param],
  455. [grad, momentum_data_fp32, param],
  456. momentum=self.momentum,
  457. nesterov=self.nesterov,
  458. weight_decay=self.weight_decay,
  459. )
  460. class WeightDecayBuilder(Optimizer):
  461. def __init__(self, weight_decay):
  462. self.weight_decay = weight_decay
  463. def _run(self, net, param_init_net, param_info):
  464. dev = scope.CurrentDeviceScope()
  465. if dev is None:
  466. dev = core.DeviceOption(caffe2_pb2.CPU)
  467. ONE = param_init_net.ConstantFill(
  468. [], "ONE_{}_{}".format(dev.device_type, dev.device_id), shape=[1], value=1.0
  469. )
  470. WD = param_init_net.ConstantFill(
  471. [],
  472. "wd_{}_{}".format(dev.device_type, dev.device_id),
  473. shape=[1],
  474. value=self.weight_decay,
  475. )
  476. if isinstance(param_info.grad, core.GradientSlice):
  477. raise ValueError("Weight decay does not yet support sparse gradients")
  478. else:
  479. net.WeightedSum(
  480. [param_info.grad, ONE, param_info.blob, WD], param_info.grad
  481. )
  482. class AdagradOptimizer(Optimizer):
  483. def __init__(
  484. self,
  485. alpha=0.01,
  486. epsilon=1e-4,
  487. decay=1,
  488. weight_decay=0.0,
  489. policy="fixed",
  490. sparse_dedup_aggregator=None,
  491. rowWise=False,
  492. engine="",
  493. lars=None,
  494. output_effective_lr=False,
  495. output_effective_lr_and_update=False,
  496. pruning_options=None,
  497. swa_options=None,
  498. ema_options=None,
  499. weight_scale=None,
  500. counter_halflife=-1,
  501. **kwargs
  502. ):
  503. super(AdagradOptimizer, self).__init__()
  504. self.alpha = alpha
  505. self.epsilon = epsilon
  506. self.decay = decay
  507. self.weight_decay = float(weight_decay)
  508. self.policy = policy
  509. self.sparse_dedup_aggregator = sparse_dedup_aggregator
  510. self.rowWise = rowWise
  511. self.engine = engine
  512. self.lars = lars
  513. self.output_effective_lr = output_effective_lr
  514. self.output_effective_lr_and_update = output_effective_lr_and_update
  515. self.counter_halflife = counter_halflife
  516. self.init_kwargs = kwargs
  517. self.weight_scale = weight_scale
  518. self._process_pruning_options(pruning_options)
  519. self._process_swa_options(swa_options)
  520. self._process_ema_options(ema_options)
  521. def _process_swa_options(self, swa_options):
  522. self.swa_enabled = True if swa_options else False
  523. if self.swa_enabled:
  524. self.swa_avg_start_it = swa_options.get("swa_avg_start_it", None)
  525. self.swa_avg_end_it = swa_options.get("swa_avg_end_it", None)
  526. self.swa_feedback_start_it = swa_options.get("swa_feedback_start_it", None)
  527. self.swa_feedback_step = swa_options.get("swa_feedback_step", None)
  528. self.swa_feedback_end_it = swa_options.get("swa_feedback_end_it", None)
  529. def _process_ema_options(self, ema_options):
  530. self.ema_enabled = True if ema_options else False
  531. if self.ema_enabled:
  532. self.ema_start = ema_options.get("ema_start", None)
  533. self.ema_end = ema_options.get("ema_end", None)
  534. self.ema_step = ema_options.get("ema_step", None)
  535. self.ema_alpha = ema_options.get("ema_alpha", None)
  536. def _process_pruning_options(self, pruning_options):
  537. self.use_mask = False
  538. if pruning_options is None:
  539. pruning_options = {}
  540. else:
  541. assert isinstance(pruning_options, dict), (
  542. "pruning_options can only "
  543. "be provided as a dictionary, currently: {}".format(pruning_options)
  544. )
  545. self.mask_tensor = pruning_options.get("mask_tensor", None)
  546. self.mask_db_path = pruning_options.get("mask_db_path", None)
  547. self.mask_db_type = pruning_options.get("mask_db_type", None)
  548. self.mask_blob_name = pruning_options.get("mask_blob_name", None)
  549. self.prune_delays = pruning_options.get("prune_delays", [])
  550. self.prune_ratios = pruning_options.get("prune_ratios", [])
  551. self.prune_block_size = pruning_options.get("prune_block_size", 1)
  552. if self.mask_tensor is not None:
  553. assert (
  554. type(self.mask_tensor) is np.ndarray
  555. ), "mask_tensor must be a numpy array!"
  556. assert self.mask_db_path is None, (
  557. "mask can be provided through either a numpy array "
  558. "or a db path, not both"
  559. )
  560. assert self.mask_db_type is None, (
  561. "mask can be provided through either a numpy array "
  562. "or a db path, not both"
  563. )
  564. assert self.mask_blob_name is None, (
  565. "mask can be provided through either a numpy array "
  566. "or a db path, not both"
  567. )
  568. self.use_mask = True
  569. if self.mask_db_path is not None or self.mask_db_type is not None:
  570. assert self.mask_db_path is not None, (
  571. "when mask is provided through db, "
  572. "db path, db type, and blob name are all needed"
  573. )
  574. assert self.mask_db_type is not None, (
  575. "when mask is provided through db, "
  576. "db path, db type, and blob name are all needed"
  577. )
  578. assert self.mask_tensor is None, (
  579. "mask can be provided through either a numpy array "
  580. "or a db path, not both"
  581. )
  582. self.use_mask = True
  583. if self.prune_delays:
  584. assert self.prune_ratios is not None and len(self.prune_delays) == len(
  585. self.prune_ratios
  586. ), "Prune Delays and prune ratios should be of the same length"
  587. assert (
  588. self.mask_tensor is None
  589. ), "Mask Tensor should be None with prune ratios"
  590. assert (
  591. self.mask_db_path is None
  592. ), "Mask DB Path should be None with prune ratios"
  593. self.use_mask = True
  594. def _run(self, net, param_init_net, param_info):
  595. param = param_info.blob
  596. grad = param_info.grad
  597. if self.alpha <= 0:
  598. return
  599. self._clear_local_lr_multiplier()
  600. if self.lars is not None and not isinstance(grad, core.GradientSlice):
  601. assert (
  602. self.weight_decay == 0
  603. ), "weight decay is not implemented for LARS yet"
  604. assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
  605. self.lars
  606. )
  607. wd, trust, lr_max = self.create_lars_inputs(
  608. param_init_net, 0.0, 1.0, np.finfo(np.float32).max
  609. )
  610. lr_lars_multiplier = net.Lars(
  611. [param, grad, wd, trust, lr_max],
  612. self.make_unique_blob_name(str(param) + "_lars"),
  613. offset=self.lars,
  614. lr_min=0.0,
  615. )
  616. current_scope = scope.CurrentDeviceScope()
  617. self._add_local_lr_multiplier(
  618. lr_lars_multiplier,
  619. is_gpu_blob=(
  620. current_scope is not None
  621. and core.IsGPUDeviceType(current_scope.device_type)
  622. ),
  623. )
  624. lr, lr_iteration = self.build_lr(
  625. net,
  626. param_init_net,
  627. base_learning_rate=self.alpha,
  628. policy=self.policy,
  629. **(self.init_kwargs)
  630. )
  631. iteration = lr_iteration
  632. if self.counter_halflife > 0:
  633. self._aux_params.shared.append(iteration)
  634. if self.rowWise:
  635. logger.debug(
  636. "Using engine {} for rowWise Adagrad to train param {}".format(
  637. self.engine, param
  638. )
  639. )
  640. shapes, types = workspace.InferShapesAndTypes([param_init_net])
  641. if str(param) not in shapes:
  642. # Type/shape inference is not available for this param, fallback
  643. # on Shape/Slice logic
  644. shape = param_init_net.Shape(param, str(param) + "_shape")
  645. num_rows = param_init_net.Slice(
  646. [shape], str(shape) + "_numrows", starts=[0], ends=[1]
  647. )
  648. param_squared_sum = param_init_net.ConstantFill(
  649. num_rows,
  650. str(param) + "_avg_squared_sum",
  651. input_as_shape=1,
  652. value=0.0,
  653. )
  654. else:
  655. param_squared_sum = param_init_net.ConstantFill(
  656. [],
  657. str(param) + "_avg_squared_sum",
  658. shape=[shapes[str(param)][0]],
  659. value=0.0,
  660. )
  661. else:
  662. logger.debug(
  663. "Using engine {} for regular Adagrad to train param {}".format(
  664. self.engine, param
  665. )
  666. )
  667. if self.engine in FP16_ENGINES:
  668. assert (
  669. self.weight_decay == 0
  670. ), "weight decay is not tested for engine: {}".format(self.engine)
  671. shapes, types = workspace.InferShapesAndTypes([param_init_net])
  672. assert str(param) in shapes, shapes
  673. shape = shapes[str(param)]
  674. param_squared_sum = param_init_net.Float16ConstantFill(
  675. [], str(param) + "_squared_sum", value=0.0, shape=shape
  676. )
  677. else:
  678. param_squared_sum = param_init_net.ConstantFill(
  679. [param], str(param) + "_squared_sum", value=0.0
  680. )
  681. if self.use_mask is True:
  682. assert (
  683. self.weight_decay == 0
  684. ), "weight decay is not implemented for use_mask yet"
  685. if self.mask_tensor is not None:
  686. if not isinstance(grad, core.GradientSlice):
  687. mask_blob = param_init_net.GivenTensorFill(
  688. [],
  689. [str(param) + "_mask"],
  690. values=self.mask_tensor,
  691. shape=self.mask_tensor.shape,
  692. )
  693. else:
  694. self.mask_tensor = self.mask_tensor.astype(np.uint8)
  695. mask_blob = param_init_net.GivenTensorBoolFill(
  696. [],
  697. [str(param) + "_mask"],
  698. values=self.mask_tensor,
  699. shape=self.mask_tensor.shape,
  700. )
  701. mask_blob = param_init_net.Cast(mask_blob, to=core.DataType.UINT8)
  702. mask_changed_blob = param_init_net.ConstantFill(
  703. [],
  704. [str(param) + "_mask_changed_blob"],
  705. value=False,
  706. dtype=core.DataType.BOOL,
  707. shape=[1],
  708. )
  709. elif (
  710. self.mask_db_path is not None or self.mask_db_type is not None
  711. ): # mask is provided through a db file
  712. # if mask_blob_name is not given use the param name to derive mask name
  713. self.mask_blob_name = self.mask_blob_name or str(param) + "_mask"
  714. mask_blob = param_init_net.Load(
  715. [],
  716. self.mask_blob_name,
  717. db=self.mask_db_path,
  718. db_type=self.mask_db_type,
  719. absolute_path=True,
  720. )
  721. if isinstance(grad, core.GradientSlice):
  722. mask_changed_blob = param_init_net.ConstantFill(
  723. [],
  724. [str(param) + "_mask_changed_blob"],
  725. value=False,
  726. dtype=core.DataType.BOOL,
  727. shape=[1],
  728. )
  729. elif self.prune_delays:
  730. last_mask_updated_iter = param_init_net.ConstantFill(
  731. [],
  732. [str(param) + "_last_mask_updated_iter"],
  733. value=-1,
  734. dtype=core.DataType.INT64,
  735. shape=[1],
  736. )
  737. if isinstance(grad, core.GradientSlice):
  738. AssertionError(
  739. "Prune Delays and Prune Ratios are currently not supported"
  740. "for sparse operators"
  741. )
  742. else:
  743. mask_blob = param_init_net.GivenTensorFill(
  744. [],
  745. [str(param) + "_empty_mask"],
  746. values=[],
  747. dtype=core.DataType.FLOAT,
  748. shape=[0],
  749. )
  750. else:
  751. raise NotImplementedError(
  752. "If mask is used, it needs a numpy array or a db file or"
  753. "a delay iter needs to be provided"
  754. )
  755. self._aux_params.local.append(param_squared_sum)
  756. if self.counter_halflife > 0:
  757. shapes, types = workspace.InferShapesAndTypes([param_init_net])
  758. if str(param) not in shapes:
  759. shape = param_init_net.Shape(param, str(param) + "_shape")
  760. num_rows = param_init_net.Slice(
  761. [shape], str(shape) + "_numrows", starts=[0], ends=[1]
  762. )
  763. update_counter = param_init_net.ConstantFill(
  764. num_rows,
  765. str(param) + "_update_counter",
  766. input_as_shape=1,
  767. value=0.0,
  768. dtype=core.DataType.DOUBLE,
  769. )
  770. prev_update_iter = param_init_net.ConstantFill(
  771. num_rows,
  772. str(param) + "_prev_update_iter",
  773. input_as_shape=1,
  774. value=0,
  775. dtype=core.DataType.INT64,
  776. )
  777. else:
  778. update_counter = param_init_net.ConstantFill(
  779. [],
  780. str(param) + "_update_counter",
  781. shape=[shapes[str(param)][0]],
  782. value=0.0,
  783. dtype=core.DataType.DOUBLE,
  784. )
  785. prev_update_iter = param_init_net.ConstantFill(
  786. [],
  787. str(param) + "_prev_update_iter",
  788. shape=[shapes[str(param)][0]],
  789. value=0,
  790. dtype=core.DataType.INT64,
  791. )
  792. self._aux_params.local.append(update_counter)
  793. self._aux_params.local.append(prev_update_iter)
  794. if self.rowWise:
  795. assert isinstance(grad, core.GradientSlice), (
  796. "If SparseAdagrad with rowWise=True, gradient must be "
  797. "a gradientslice. PLease ensure that rowWise is not enabled "
  798. "for the dense Adagrad optimizer, as it is not supported."
  799. )
  800. shapes, _ = workspace.InferShapesAndTypes([param_init_net])
  801. param_shape = shapes[str(param)]
  802. weight_decay = 0.0
  803. if isinstance(grad, core.GradientSlice):
  804. if len(param_shape) == 1:
  805. weight_decay = 0.0
  806. logger.warn(
  807. "SKIPPING weight decay on 1d sparse param: {}.shape is {}".format(
  808. str(param), param_shape
  809. )
  810. )
  811. else:
  812. weight_decay = self.weight_decay
  813. else:
  814. # Skip weight decay for 1d parameters
  815. if len(param_shape) == 1:
  816. weight_decay = 0.0
  817. logger.warning(
  818. "SKIPPING weight decay on 1d dense param: {}.shape is {}".format(
  819. str(param), param_shape
  820. )
  821. )
  822. else:
  823. weight_decay = self.weight_decay
  824. logger.debug(
  825. "weight_decay for {} (shape:{}): {}".format(
  826. str(param), param_shape, weight_decay
  827. )
  828. )
  829. if isinstance(grad, core.GradientSlice):
  830. assert (
  831. self.decay == 1.0
  832. ), "Decay is not implemented for SparseAdagrad and must be set to 1"
  833. grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
  834. input_args = [param, param_squared_sum, grad.indices, grad.values, lr]
  835. output_args = [param, param_squared_sum]
  836. if self.rowWise:
  837. if self.use_mask is True:
  838. op = "MaskedRowWiseSparseAdagrad"
  839. assert (
  840. weight_decay == 0
  841. ), "weight decay is not implemented for {} yet".format(op)
  842. input_args += [mask_blob, mask_changed_blob]
  843. else:
  844. if self.counter_halflife > 0:
  845. input_args += [update_counter]
  846. op = "RowWiseSparseAdagrad"
  847. else:
  848. if self.use_mask is True:
  849. op = "MaskedSparseAdagrad"
  850. assert (
  851. weight_decay == 0
  852. ), "weight decay is not implemented for {} yet".format(op)
  853. input_args += [mask_blob, mask_changed_blob]
  854. else:
  855. op = "SparseAdagrad"
  856. logger.debug("using {} for {}".format(op, str(param)))
  857. if self.prune_delays:
  858. input_args += [lr_iteration, last_mask_updated_iter]
  859. output_args += [mask_blob, last_mask_updated_iter]
  860. if weight_decay > 0 and self.counter_halflife == -1:
  861. net.__getattr__(op)(
  862. input_args,
  863. output_args,
  864. epsilon=self.epsilon,
  865. weight_decay=weight_decay,
  866. engine=self.engine,
  867. )
  868. elif weight_decay > 0 and self.counter_halflife != -1:
  869. net.__getattr__(op)(
  870. input_args,
  871. output_args,
  872. epsilon=self.epsilon,
  873. weight_decay=weight_decay,
  874. engine=self.engine,
  875. counter_halflife=self.counter_halflife,
  876. )
  877. else:
  878. net.__getattr__(op)(
  879. input_args, output_args, epsilon=self.epsilon, engine=self.engine
  880. )
  881. if self.counter_halflife > 0:
  882. net.RowWiseCounter(
  883. [prev_update_iter, update_counter, grad.indices, iteration],
  884. [prev_update_iter, update_counter],
  885. counter_halflife=self.counter_halflife,
  886. )
  887. else:
  888. input_args = [param, param_squared_sum, grad, lr]
  889. output_args = [param, param_squared_sum]
  890. if self.output_effective_lr_and_update:
  891. assert (
  892. self.use_mask is False
  893. ), "MaskedAdagrad doesn't support outputting effective_lr_and_update"
  894. output_args.append(str(param) + "_effective_lr")
  895. output_args.append(str(param) + "_update")
  896. elif self.output_effective_lr:
  897. assert (
  898. self.use_mask is False
  899. ), "MaskedAdagrad doesn't support outputting effective_lr"
  900. output_args.append(str(param) + "_effective_lr")
  901. if self.use_mask is True:
  902. input_args += [mask_blob]
  903. if self.prune_delays:
  904. input_args += [lr_iteration, last_mask_updated_iter]
  905. output_args += [mask_blob, last_mask_updated_iter]
  906. if self.use_mask:
  907. assert (
  908. weight_decay == 0
  909. ), "weight decay is not implemented for use_mask yet"
  910. net.MaskedAdagrad(
  911. input_args,
  912. output_args,
  913. epsilon=self.epsilon,
  914. decay=float(self.decay),
  915. block_size=self.prune_block_size,
  916. delays=self.prune_delays,
  917. prune_ratios=self.prune_ratios,
  918. engine=self.engine,
  919. )
  920. else:
  921. if weight_decay > 0:
  922. net.Adagrad(
  923. input_args,
  924. output_args,
  925. epsilon=self.epsilon,
  926. decay=float(self.decay),
  927. weight_decay=weight_decay,
  928. engine=self.engine,
  929. )
  930. else:
  931. net.Adagrad(
  932. input_args,
  933. output_args,
  934. epsilon=self.epsilon,
  935. decay=float(self.decay),
  936. engine=self.engine,
  937. )
  938. if self.swa_enabled:
  939. param_swa = str(param) + "_swa"
  940. if not param_init_net.BlobIsDefined(param_swa):
  941. param_init_net.ConstantFill([param], param_swa, value=0.0)
  942. self._aux_params.local.append(param_swa)
  943. net.SWA(
  944. [param, param_swa, lr_iteration],
  945. [param, param_swa],
  946. avg_start=self.swa_avg_start_it,
  947. avg_end=self.swa_avg_end_it,
  948. feedback_start=self.swa_feedback_start_it,
  949. feedback_step=self.swa_feedback_step,
  950. feedback_end=self.swa_feedback_end_it,
  951. )
  952. if self.ema_enabled:
  953. param_ema = str(param) + "_ema"
  954. if not param_init_net.BlobIsDefined(param_ema):
  955. param_init_net.ConstantFill([param], param_ema, value=0.0)
  956. self._aux_params.local.append(param_ema)
  957. net.EMA(
  958. [param, param_ema, lr_iteration],
  959. [param, param_ema],
  960. ema_start=self.ema_start,
  961. ema_end=self.ema_end,
  962. ema_step=self.ema_step,
  963. ema_alpha=self.ema_alpha,
  964. )
  965. if self.weight_scale:
  966. net.WeightScale(
  967. [param, lr_iteration],
  968. [param],
  969. stepsize=self.weight_scale.stepsize,
  970. upper_bound_iter=self.weight_scale.upper_bound_iter,
  971. scale=float(self.weight_scale.scale),
  972. )
  973. if self.weight_scale.to_aux:
  974. net.WeightScale(
  975. [param_squared_sum, lr_iteration],
  976. [param_squared_sum],
  977. stepsize=self.weight_scale.stepsize,
  978. upper_bound_iter=self.weight_scale.upper_bound_iter,
  979. scale=float(self.weight_scale.scale),
  980. )
  981. def scale_learning_rate(self, scale):
  982. self.alpha *= scale
  983. return
  984. class WngradOptimizer(Optimizer):
  985. def __init__(
  986. self,
  987. alpha=1.0,
  988. epsilon=1e-9,
  989. policy="fixed",
  990. sparse_dedup_aggregator=None,
  991. engine="",
  992. moment_init=100.0,
  993. lars=None,
  994. output_effective_lr=False,
  995. output_effective_lr_and_update=False,
  996. **kwargs
  997. ):
  998. super(WngradOptimizer, self).__init__()
  999. self.alpha = alpha
  1000. self.epsilon = epsilon
  1001. self.policy = policy
  1002. self.sparse_dedup_aggregator = sparse_dedup_aggregator
  1003. self.engine = engine
  1004. self.moment_init = moment_init
  1005. self.lars = lars
  1006. self.output_effective_lr = output_effective_lr
  1007. self.output_effective_lr_and_update = output_effective_lr_and_update
  1008. self.init_kwargs = kwargs
  1009. def _run(self, net, param_init_net, param_info):
  1010. param = param_info.blob
  1011. grad = param_info.grad
  1012. if self.alpha <= 0:
  1013. return
  1014. self._clear_local_lr_multiplier()
  1015. if self.lars is not None and not isinstance(grad, core.GradientSlice):
  1016. assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
  1017. self.lars
  1018. )
  1019. wd, trust, lr_max = self.create_lars_inputs(
  1020. param_init_net, 0.0, 1.0, np.finfo(np.float32).max
  1021. )
  1022. lr_lars_multiplier = net.Lars(
  1023. [param, grad, wd, trust, lr_max],
  1024. self.make_unique_blob_name(str(param) + "_lars"),
  1025. offset=self.lars,
  1026. lr_min=0.0,
  1027. )
  1028. current_scope = scope.CurrentDeviceScope()
  1029. self._add_local_lr_multiplier(
  1030. lr_lars_multiplier,
  1031. is_gpu_blob=(
  1032. current_scope is not None
  1033. and core.IsGPUDeviceType(current_scope.device_type)
  1034. ),
  1035. )
  1036. lr, _ = self.build_lr(
  1037. net,
  1038. param_init_net,
  1039. base_learning_rate=self.alpha,
  1040. policy=self.policy,
  1041. **(self.init_kwargs)
  1042. )
  1043. moment = param_init_net.ConstantFill(
  1044. [], str(param) + "_moment", shape=[1], value=self.moment_init
  1045. )
  1046. self._aux_params.local.append(moment)
  1047. if isinstance(grad, core.GradientSlice):
  1048. grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
  1049. net.SparseWngrad(
  1050. [param, moment, grad.indices, grad.values, lr],
  1051. [param, moment],
  1052. epsilon=self.epsilon,
  1053. engine=self.engine,
  1054. )
  1055. else:
  1056. output_args = [param, moment]
  1057. if self.output_effective_lr_and_update:
  1058. output_args.append(str(param) + "_effective_lr")
  1059. output_args.append(str(param) + "_update")
  1060. elif self.output_effective_lr:
  1061. output_args.append(str(param) + "_effective_lr")
  1062. net.Wngrad(
  1063. [param, moment, grad, lr],
  1064. output_args,
  1065. epsilon=self.epsilon,
  1066. engine=self.engine,
  1067. )
  1068. def scale_learning_rate(self, scale):
  1069. self.alpha *= scale
  1070. return
  1071. class StormOptimizer(Optimizer):
  1072. def __init__(
  1073. self,
  1074. lr=0.1,
  1075. momentum=10.0,
  1076. beta=0.1,
  1077. grad_sq_init=0.01,
  1078. policy="fixed",
  1079. sparse_dedup_aggregator=None,
  1080. lars=None,
  1081. **kwargs
  1082. ):
  1083. """Constructor function to add STORM Optimizer
  1084. Args:
  1085. lr: learning rate scaling (called k in the original paper)
  1086. momentum: momentum scaling (called c in the original paper)
  1087. beta: initial value of denominator in adaptive learning rate (
  1088. called c in the original paper)
  1089. grad_sq_init: initial value of gradient squared accumulator.
  1090. policy: specifies how learning rate should be applied, options are
  1091. 'fixed', 'step', 'exp', etc.
  1092. sparse_dedup_aggregator: specifies deduplication strategy for
  1093. gradient slices. Works while using sparse gradients. Options
  1094. include 'mean' and 'sum'.
  1095. lars: lars offset.
  1096. """
  1097. super(StormOptimizer, self).__init__()
  1098. self.lr = lr
  1099. self.momentum = momentum
  1100. self.beta = beta
  1101. self.grad_sq_init = grad_sq_init
  1102. self.policy = policy
  1103. self.sparse_dedup_aggregator = sparse_dedup_aggregator
  1104. self.lars = lars
  1105. self.init_kwargs = kwargs
  1106. def _run(self, net, param_init_net, param_info):
  1107. param = param_info.blob
  1108. grad = param_info.grad
  1109. if self.lr <= 0:
  1110. return
  1111. self._clear_local_lr_multiplier()
  1112. if self.lars is not None and not isinstance(grad, core.GradientSlice):
  1113. assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
  1114. self.lars
  1115. )
  1116. wd, trust, lr_max = self.create_lars_inputs(
  1117. param_init_net, 0.0, 1.0, np.finfo(np.float32).max
  1118. )
  1119. lr_lars_multiplier = net.Lars(
  1120. [param, grad, wd, trust, lr_max],
  1121. self.make_unique_blob_name(str(param) + "_lars"),
  1122. offset=self.lars,
  1123. lr_min=0.0,
  1124. )
  1125. current_scope = scope.CurrentDeviceScope()
  1126. self._add_local_lr_multiplier(
  1127. lr_lars_multiplier,
  1128. is_gpu_blob=(
  1129. current_scope is not None
  1130. and core.IsGPUDeviceType(current_scope.device_type)
  1131. ),
  1132. )
  1133. lr, _ = self.build_lr(
  1134. net,
  1135. param_init_net,
  1136. base_learning_rate=self.lr,
  1137. policy=self.policy,
  1138. **(self.init_kwargs)
  1139. )
  1140. moment = param_init_net.ConstantFill(param, str(param) + "_moment", value=0.0)
  1141. self._aux_params.local.append(moment)
  1142. grad_sq_sum = param_init_net.ConstantFill(
  1143. [], str(param) + "_grad_sq_sum", shape=[1], value=self.grad_sq_init
  1144. )
  1145. self._aux_params.local.append(grad_sq_sum)
  1146. if isinstance(grad, core.GradientSlice):
  1147. grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
  1148. net.SparseStorm(
  1149. [param, moment, grad_sq_sum, grad.values, grad.indices, lr],
  1150. [param, moment, grad_sq_sum],
  1151. momentum=self.momentum,
  1152. beta=self.beta,
  1153. )
  1154. else:
  1155. net.Storm(
  1156. [param, moment, grad_sq_sum, grad, lr],
  1157. [param, moment, grad_sq_sum],
  1158. momentum=self.momentum,
  1159. beta=self.beta,
  1160. )
  1161. def scale_learning_rate(self, scale):
  1162. self.lr *= scale
  1163. class AdadeltaOptimizer(Optimizer):
  1164. def __init__(
  1165. self,
  1166. alpha=0.01,
  1167. epsilon=1e-4,
  1168. decay=0.95,
  1169. policy="fixed",
  1170. sparse_dedup_aggregator=None,
  1171. engine="",
  1172. **kwargs
  1173. ):
  1174. """Constructor function to add Adadelta Optimizer
  1175. Args:
  1176. alpha: learning rate
  1177. epsilon: attribute of Adadelta to avoid numerical issues
  1178. decay: attribute of Adadelta to decay the squared gradient sum
  1179. policy: specifies how learning rate should be applied, options are
  1180. "fixed", "step", "exp", etc.
  1181. sparse_dedup_aggregator: specifies deduplication strategy for
  1182. gradient slices. Works while using sparse gradients. Options
  1183. include "mean" and "sum".
  1184. engine: the engine used, options include "", "CUDNN", etc.
  1185. """
  1186. super(AdadeltaOptimizer, self).__init__()
  1187. self.alpha = alpha
  1188. self.epsilon = epsilon
  1189. self.decay = decay
  1190. self.policy = policy
  1191. self.sparse_dedup_aggregator = sparse_dedup_aggregator
  1192. self.engine = engine
  1193. self.init_kwargs = kwargs
  1194. def _run(self, net, param_init_net, param_info):
  1195. param = param_info.blob
  1196. grad = param_info.grad
  1197. if self.alpha <= 0:
  1198. return
  1199. lr, _ = self.build_lr(
  1200. net,
  1201. param_init_net,
  1202. base_learning_rate=self.alpha,
  1203. policy=self.policy,
  1204. **(self.init_kwargs)
  1205. )
  1206. moment = param_init_net.ConstantFill(
  1207. [param], str(param) + "_squared_moment", value=0.0
  1208. )
  1209. moment_update = param_init_net.ConstantFill(
  1210. [param], str(param) + "_squared_moment_update", value=0.0
  1211. )
  1212. self._aux_params.local.append(moment)
  1213. self._aux_params.local.append(moment_update)
  1214. if isinstance(grad, core.GradientSlice):
  1215. grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
  1216. net.SparseAdadelta(
  1217. [param, moment, moment_update, grad.indices, grad.values, lr],
  1218. [param, moment, moment_update],
  1219. epsilon=self.epsilon,
  1220. decay=self.decay,
  1221. engine=self.engine,
  1222. )
  1223. else:
  1224. net.Adadelta(
  1225. [param, moment, moment_update, grad, lr],
  1226. [param, moment, moment_update],
  1227. epsilon=self.epsilon,
  1228. decay=self.decay,
  1229. engine=self.engine,
  1230. )
  1231. def scale_learning_rate(self, scale):
  1232. self.alpha *= scale
  1233. return
  1234. class FtrlOptimizer(Optimizer):
  1235. def __init__(
  1236. self,
  1237. alpha=0.01,
  1238. beta=1e-4,
  1239. lambda1=0,
  1240. lambda2=0,
  1241. sparse_dedup_aggregator=None,
  1242. engine="",
  1243. ):
  1244. super(FtrlOptimizer, self).__init__()
  1245. self.alpha = alpha
  1246. self.beta = beta
  1247. self.lambda1 = lambda1
  1248. self.lambda2 = lambda2
  1249. self.sparse_dedup_aggregator = sparse_dedup_aggregator
  1250. self.engine = engine
  1251. def _run(self, net, param_init_net, param_info):
  1252. param = param_info.blob
  1253. grad = param_info.grad
  1254. if self.alpha <= 0:
  1255. return
  1256. nz = param_init_net.ConstantFill(
  1257. [param], str(param) + "_ftrl_nz", extra_shape=[2], value=0.0
  1258. )
  1259. self._aux_params.local.append(nz)
  1260. if isinstance(grad, core.GradientSlice):
  1261. grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
  1262. net.SparseFtrl(
  1263. [param, nz, grad.indices, grad.values],
  1264. [param, nz],
  1265. engine=self.engine,
  1266. alpha=self.alpha,
  1267. beta=self.beta,
  1268. lambda1=self.lambda1,
  1269. lambda2=self.lambda2,
  1270. )
  1271. else:
  1272. net.Ftrl(
  1273. [param, nz, grad],
  1274. [param, nz],
  1275. engine=self.engine,
  1276. alpha=self.alpha,
  1277. beta=self.beta,
  1278. lambda1=self.lambda1,
  1279. lambda2=self.lambda2,
  1280. )
  1281. def scale_learning_rate(self, scale):
  1282. self.alpha *= scale
  1283. return
  1284. class GFtrlOptimizer(Optimizer):
  1285. """Group Lasso FTRL Optimizer."""
  1286. def __init__(
  1287. self,
  1288. alpha=0.01,
  1289. beta=1e-4,
  1290. lambda1=0,
  1291. lambda2=0,
  1292. sparse_dedup_aggregator=None,
  1293. engine="",
  1294. ):
  1295. super(GFtrlOptimizer, self).__init__()
  1296. self.alpha = alpha
  1297. self.beta = beta
  1298. self.lambda1 = lambda1
  1299. self.lambda2 = lambda2
  1300. self.sparse_dedup_aggregator = sparse_dedup_aggregator
  1301. self.engine = engine
  1302. def _run(self, net, param_init_net, param_info):
  1303. param = param_info.blob
  1304. grad = param_info.grad
  1305. if self.alpha <= 0:
  1306. return
  1307. nz = param_init_net.ConstantFill(
  1308. [param], str(param) + "_gftrl_nz", extra_shape=[2], value=0.0
  1309. )
  1310. self._aux_params.local.append(nz)
  1311. net.GFtrl(
  1312. [param, nz, grad],
  1313. [param, nz],
  1314. engine=self.engine,
  1315. alpha=self.alpha,
  1316. beta=self.beta,
  1317. lambda1=self.lambda1,
  1318. lambda2=self.lambda2,
  1319. )
  1320. def scale_learning_rate(self, scale):
  1321. self.alpha *= scale
  1322. return
  1323. class AdamOptimizer(Optimizer):
  1324. def __init__(
  1325. self,
  1326. alpha=0.001,
  1327. beta1=0.9,
  1328. beta2=0.999,
  1329. epsilon=1e-8,
  1330. policy="fixed",
  1331. use_lr_adaption=False,
  1332. lr_alpha=0.01,
  1333. normalized_lr_adaption=True,
  1334. sparse_dedup_aggregator=None,
  1335. rowWise=False,
  1336. engine="",
  1337. enableRAdam=False,
  1338. use_smart_decay=False, # See https://fburl.com/2jdiwrhy for context.
  1339. **kwargs
  1340. ):
  1341. super(AdamOptimizer, self).__init__()
  1342. self.alpha = alpha
  1343. self.beta1 = beta1
  1344. self.beta2 = beta2
  1345. self.epsilon = epsilon
  1346. self.policy = policy
  1347. self.use_lr_adaption = use_lr_adaption
  1348. self.lr_alpha = lr_alpha
  1349. self.normalized_lr_adaption = normalized_lr_adaption
  1350. self.sparse_dedup_aggregator = sparse_dedup_aggregator
  1351. self.rowWise = rowWise
  1352. self.engine = engine
  1353. self.enableRAdam = enableRAdam
  1354. if use_smart_decay:
  1355. if rowWise:
  1356. raise NotImplementedError(('Smart decay is not implemented for rowWise Adam. '
  1357. 'Set rowWise or use_smart_decay to False.'))
  1358. if enableRAdam:
  1359. raise NotImplementedError(('Smart decay is not implemented for RAdam. '
  1360. 'Set enableRAdam or use_smart_decay to False.'))
  1361. if use_lr_adaption:
  1362. raise NotImplementedError(('Smart decay is not implemented with lr_adaption. '
  1363. 'Set use_lr_adaption or use_smart_decay to False.'))
  1364. self.use_smart_decay = use_smart_decay
  1365. self.init_kwargs = kwargs
  1366. def _run(self, net, param_init_net, param_info):
  1367. param = param_info.blob
  1368. grad = param_info.grad
  1369. if self.alpha <= 0:
  1370. return
  1371. lr, iteration = self.build_lr(
  1372. net,
  1373. param_init_net,
  1374. base_learning_rate=self.alpha,
  1375. policy=self.policy,
  1376. **(self.init_kwargs)
  1377. )
  1378. m1 = param_init_net.ConstantFill([param], param + "_first_moment", value=0.0)
  1379. if self.rowWise:
  1380. shapes, types = workspace.InferShapesAndTypes([param_init_net])
  1381. m2 = param_init_net.ConstantFill(
  1382. [], param + "_avg_second_moment", shape=[shapes[param][0]], value=0.0
  1383. )
  1384. else:
  1385. m2 = param_init_net.ConstantFill(
  1386. [param], param + "_second_moment", value=0.0
  1387. )
  1388. # Initialize "minibatch in which this parameter was last seen" for smart decay.
  1389. if self.use_smart_decay:
  1390. shapes, _ = workspace.InferShapesAndTypes([param_init_net])
  1391. last_seen = param_init_net.ConstantFill(
  1392. [], param + "_last_seen", shape=[shapes[param][0]], value=0, dtype=core.DataType.INT64
  1393. )
  1394. self._aux_params.local.append(last_seen)
  1395. self._aux_params.shared.append(iteration)
  1396. self._aux_params.local.append(m1)
  1397. self._aux_params.local.append(m2)
  1398. if self.rowWise:
  1399. assert isinstance(grad, core.GradientSlice), (
  1400. "If SparseAdam with rowWise=True, gradient must be "
  1401. "a gradientslice. PLease ensure that rowWise is not enabled "
  1402. "for the dense Adam optimizer, as it is not supported."
  1403. )
  1404. output_blobs = [param, m1, m2]
  1405. if self.use_smart_decay:
  1406. output_blobs.append(last_seen)
  1407. if self.use_lr_adaption:
  1408. effective_grad = str(param) + "_effective_grad"
  1409. output_blobs.append(effective_grad)
  1410. if isinstance(grad, core.GradientSlice):
  1411. grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
  1412. if self.rowWise:
  1413. op = "RowWiseSparseAdam"
  1414. elif self.use_smart_decay:
  1415. op = "SmartDecaySparseAdam"
  1416. else:
  1417. op = "SparseAdam"
  1418. # Currently, only SparseAdam support RAdam, other Adam Ops will support later
  1419. if op == "SparseAdam":
  1420. net.__getattr__(op)(
  1421. [param, m1, m2, grad.indices, grad.values, lr, iteration],
  1422. output_blobs,
  1423. beta1=self.beta1,
  1424. beta2=self.beta2,
  1425. epsilon=self.epsilon,
  1426. enableRAdam=self.enableRAdam,
  1427. )
  1428. elif op == "SmartDecaySparseAdam":
  1429. net.__getattr__(op)(
  1430. [param, m1, m2, last_seen, grad.indices, grad.values, lr, iteration],
  1431. output_blobs,
  1432. beta1=self.beta1,
  1433. beta2=self.beta2,
  1434. epsilon=self.epsilon,
  1435. )
  1436. else:
  1437. assert (
  1438. not self.enableRAdam
  1439. ), "Currently, RowWiseSparseAdam is not supported by RAdam!"
  1440. net.__getattr__(op)(
  1441. [param, m1, m2, grad.indices, grad.values, lr, iteration],
  1442. output_blobs,
  1443. beta1=self.beta1,
  1444. beta2=self.beta2,
  1445. epsilon=self.epsilon,
  1446. )
  1447. if self.use_lr_adaption:
  1448. net.LearningRateAdaption(
  1449. [lr, grad.values, effective_grad],
  1450. [lr],
  1451. lr_alpha=self.lr_alpha,
  1452. normalized_lr_adaption=self.normalized_lr_adaption,
  1453. )
  1454. else:
  1455. net.Adam(
  1456. [param, m1, m2, grad, lr, iteration],
  1457. output_blobs,
  1458. beta1=self.beta1,
  1459. beta2=self.beta2,
  1460. epsilon=self.epsilon,
  1461. )
  1462. if self.use_lr_adaption:
  1463. net.LearningRateAdaption(
  1464. [lr, grad, effective_grad],
  1465. [lr],
  1466. lr_alpha=self.lr_alpha,
  1467. normalized_lr_adaption=self.normalized_lr_adaption,
  1468. )
  1469. def scale_learning_rate(self, scale):
  1470. self.alpha *= scale
  1471. return
  1472. class DecayAdagradOptimizer(Optimizer):
  1473. def __init__(
  1474. self,
  1475. alpha=0.01,
  1476. beta1=0.0,
  1477. beta2=0.999,
  1478. epsilon=0.1,
  1479. weight_decay=0.0,
  1480. ema_options=None,
  1481. bias_correction_first=True,
  1482. policy="fixed",
  1483. engine="",
  1484. **kwargs
  1485. ):
  1486. super(DecayAdagradOptimizer, self).__init__()
  1487. self.alpha = alpha
  1488. self.beta1 = beta1
  1489. self.beta2 = beta2
  1490. self.epsilon = epsilon
  1491. self.weight_decay = weight_decay
  1492. self.bias_correction_first = bias_correction_first
  1493. self.policy = policy
  1494. self.engine = engine
  1495. self.init_kwargs = kwargs
  1496. self._process_ema_options(ema_options)
  1497. def _process_ema_options(self, ema_options):
  1498. self.ema_enabled = True if ema_options else False
  1499. if self.ema_enabled:
  1500. self.ema_start = ema_options.get("ema_start", None)
  1501. self.ema_end = ema_options.get("ema_end", None)
  1502. self.ema_step = ema_options.get("ema_step", None)
  1503. self.ema_alpha = ema_options.get("ema_alpha", None)
  1504. def _run(self, net, param_init_net, param_info):
  1505. param = param_info.blob
  1506. grad = param_info.grad
  1507. if self.alpha <= 0:
  1508. return
  1509. lr, iteration = self.build_lr(
  1510. net,
  1511. param_init_net,
  1512. base_learning_rate=self.alpha,
  1513. policy=self.policy,
  1514. **(self.init_kwargs)
  1515. )
  1516. if isinstance(grad, core.GradientSlice):
  1517. # hack for position weighted.
  1518. param_squared_sum = param_init_net.ConstantFill([param], param + "_squared_sum", value=0.0)
  1519. self._aux_params.local.append(param_squared_sum)
  1520. output_blobs = [param, param_squared_sum]
  1521. net.SparseAdagrad(
  1522. [param, param_squared_sum, grad.indices, grad.values, lr],
  1523. output_blobs,
  1524. epsilon=self.epsilon,
  1525. )
  1526. else:
  1527. m1 = param_init_net.ConstantFill([param], param + "_first_mo1ment", value=0.0)
  1528. m2 = param_init_net.ConstantFill([param], param + "_second_moment", value=0.0)
  1529. self._aux_params.shared.append(iteration)
  1530. self._aux_params.local.append(m1)
  1531. self._aux_params.local.append(m2)
  1532. output_blobs = [param, m1, m2]
  1533. net.DecayAdagrad(
  1534. [param, m1, m2, grad, lr, iteration],
  1535. output_blobs,
  1536. beta1=self.beta1,
  1537. beta2=self.beta2,
  1538. epsilon=self.epsilon,
  1539. weight_decay=self.weight_decay,
  1540. bias_correction_first=self.bias_correction_first,
  1541. )
  1542. if self.ema_enabled:
  1543. param_ema = str(param) + "_ema"
  1544. if not param_init_net.BlobIsDefined(param_ema):
  1545. param_init_net.ConstantFill([param], param_ema, value=0.0)
  1546. self._aux_params.local.append(param_ema)
  1547. net.EMA(
  1548. [param, param_ema, iteration],
  1549. [param, param_ema],
  1550. ema_start=self.ema_start,
  1551. ema_end=self.ema_end,
  1552. ema_step=self.ema_step,
  1553. ema_alpha=self.ema_alpha,
  1554. )
  1555. def scale_learning_rate(self, scale):
  1556. self.alpha *= scale
  1557. return
  1558. class YellowFinOptimizer(Optimizer):
  1559. """YellowFin: An automatic tuner for momentum SGD
  1560. See https://arxiv.org/abs/1706.03471 for more details. This implementation
  1561. has separate learning rate and momentum per each parameter."""
  1562. def __init__(
  1563. self,
  1564. alpha=0.1,
  1565. mu=0.0,
  1566. beta=0.999,
  1567. curv_win_width=20,
  1568. zero_debias=True,
  1569. epsilon=0.1 ** 6,
  1570. policy="fixed",
  1571. sparse_dedup_aggregator=None,
  1572. **kwargs
  1573. ):
  1574. super(YellowFinOptimizer, self).__init__()
  1575. self.alpha = alpha
  1576. self.mu = mu
  1577. self.beta = beta
  1578. self.curv_win_width = curv_win_width
  1579. self.zero_debias = zero_debias
  1580. self.epsilon = epsilon
  1581. self.policy = policy
  1582. self.sparse_dedup_aggregator = sparse_dedup_aggregator
  1583. self.init_kwargs = kwargs
  1584. def _run(self, net, param_init_net, param_info):
  1585. # Note: This is number of persistent scalars in YellowFin optimizer.
  1586. # It should always be the number of scalars being used. The same
  1587. # number should be used in class for the operation.
  1588. SCALARS_MEMORY_SIZE = 5
  1589. param = param_info.blob
  1590. grad = param_info.grad
  1591. moment = param_init_net.ConstantFill([param], param + "_moment", value=0.0)
  1592. curv_win = param_init_net.ConstantFill(
  1593. [], param + "_curv_win", shape=[self.curv_win_width], value=0.0
  1594. )
  1595. g_avg = param_init_net.ConstantFill([param], param + "_g_avg", value=0.0)
  1596. g2_avg = param_init_net.ConstantFill([param], param + "_g2_avg", value=0.0)
  1597. lr_avg = param_init_net.ConstantFill(
  1598. [], param + "_lr_avg", shape=[1], value=self.alpha
  1599. )
  1600. mu_avg = param_init_net.ConstantFill(
  1601. [], param + "_mu_avg", shape=[1], value=self.mu
  1602. )
  1603. scalars_memory = param_init_net.ConstantFill(
  1604. [], param + "_scalars_memory", shape=[SCALARS_MEMORY_SIZE], value=0.0
  1605. )
  1606. assert self.alpha > 0
  1607. assert not isinstance(
  1608. grad, core.GradientSlice
  1609. ), "YellowFin does not support sparse gradients"
  1610. iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=0)
  1611. self._aux_params.shared.append(iteration)
  1612. self._aux_params.local.append(moment)
  1613. self._aux_params.local.append(lr_avg)
  1614. self._aux_params.local.append(mu_avg)
  1615. self._aux_params.local.append(curv_win)
  1616. self._aux_params.local.append(g_avg)
  1617. self._aux_params.local.append(g2_avg)
  1618. self._aux_params.local.append(scalars_memory)
  1619. yf_in_out_args = [
  1620. param,
  1621. moment,
  1622. lr_avg,
  1623. mu_avg,
  1624. curv_win,
  1625. g_avg,
  1626. g2_avg,
  1627. scalars_memory,
  1628. ]
  1629. net.YellowFin(
  1630. yf_in_out_args + [grad, iteration],
  1631. yf_in_out_args,
  1632. beta=self.beta,
  1633. epsilon=self.epsilon,
  1634. curv_win_width=self.curv_win_width,
  1635. zero_debias=self.zero_debias,
  1636. )
  1637. def scale_learning_rate(self, scale):
  1638. self.alpha *= scale
  1639. return
  1640. class RmsPropOptimizer(Optimizer):
  1641. def __init__(
  1642. self,
  1643. alpha=0.01,
  1644. decay=0.9,
  1645. momentum=0.0,
  1646. epsilon=1e-5,
  1647. policy="fixed",
  1648. engine="",
  1649. **kwargs
  1650. ):
  1651. super(RmsPropOptimizer, self).__init__()
  1652. self.alpha = alpha
  1653. self.decay = decay
  1654. self.momentum = momentum
  1655. self.epsilon = epsilon
  1656. self.policy = policy
  1657. self.engine = engine
  1658. self.init_kwargs = kwargs
  1659. def _run(self, net, param_init_net, param_info):
  1660. param = param_info.blob
  1661. grad = param_info.grad
  1662. assert self.alpha > 0
  1663. assert not isinstance(
  1664. grad, core.GradientSlice
  1665. ), "RmsPropOptimizer doesn't support sparse gradients"
  1666. dev = scope.CurrentDeviceScope()
  1667. if dev is None:
  1668. dev = core.DeviceOption(caffe2_pb2.CPU)
  1669. ONE = param_init_net.ConstantFill(
  1670. [], "ONE_{}_{}".format(dev.device_type, dev.device_id), shape=[1], value=1.0
  1671. )
  1672. lr, _ = self.build_lr(
  1673. net,
  1674. param_init_net,
  1675. base_learning_rate=-self.alpha,
  1676. policy=self.policy,
  1677. **(self.init_kwargs)
  1678. )
  1679. grad_o = param_init_net.ConstantFill(
  1680. [param], str(param) + "_grad_o", values=0.0
  1681. )
  1682. ms = param_init_net.ConstantFill(
  1683. [param], str(param) + "_mean_squares", values=0.0
  1684. )
  1685. mom = param_init_net.ConstantFill([param], str(param) + "_momentum", values=0.0)
  1686. self._aux_params.local.append(ms)
  1687. self._aux_params.local.append(mom)
  1688. net.RmsProp(
  1689. [grad, ms, mom, ONE],
  1690. [grad_o, ms, mom],
  1691. decay=self.decay,
  1692. momentum=self.momentum,
  1693. epsilon=self.epsilon,
  1694. engine=self.engine,
  1695. )
  1696. net.MomentumSGDUpdate([grad_o, mom, lr, param], [grad_o, mom, param])
  1697. def scale_learning_rate(self, scale):
  1698. self.alpha *= scale
  1699. return
  1700. def _get_param_to_device(model):
  1701. # Infer blob devices by going through the net and param_init_net
  1702. # ops and observing the device used to create or use the blob.
  1703. param_to_device = core.InferBlobDevices(model.net)
  1704. param_to_device.update(core.InferBlobDevices(model.param_init_net))
  1705. return param_to_device
  1706. def get_param_device(param_name, grad, param_to_device=None, default_device=None):
  1707. device = default_device
  1708. param_to_device = param_to_device or {}
  1709. # We first check if parameter's device has been inferred. If not,
  1710. # we check the gradient. This can happen if parameter is not output
  1711. # by any blob but created by a FetchBlob.
  1712. if param_name in param_to_device:
  1713. device = param_to_device[param_name]
  1714. else:
  1715. if isinstance(grad, core.GradientSlice):
  1716. grad = grad
  1717. if str(grad.values) in param_to_device:
  1718. device = param_to_device[str(grad.values)]
  1719. elif str(grad.indices) in param_to_device:
  1720. device = param_to_device[str(grad.indices)]
  1721. else:
  1722. grad_name = str(grad)
  1723. if grad_name in param_to_device:
  1724. device = param_to_device[grad_name]
  1725. assert device is not None, "Cannot infer device for {}: no op creates it".format(
  1726. param_name
  1727. )
  1728. return device
  1729. def get_lr_injection():
  1730. """
  1731. Gets current value for lr_injection, a multiplier for all base
  1732. learning rates.
  1733. Must set allow_lr_injection=True when building optimizer, as it
  1734. relies on synchronization over CPU.
  1735. """
  1736. return workspace.FetchBlob(_LEARNING_RATE_INJECTION)
  1737. def set_lr_injection(lr_injection_value):
  1738. """
  1739. Sets lr_injection, a multiplier for all base learning rates.
  1740. Must set allow_lr_injection=True when building optimizer, as it
  1741. relies on synchronization over CPU.
  1742. """
  1743. workspace.FeedBlob(
  1744. _LEARNING_RATE_INJECTION,
  1745. np.array([float(lr_injection_value)], dtype=np.float32),
  1746. )
  1747. def _calc_norm_ratio(model, params, name_scope, param_to_device, max_gradient_norm):
  1748. with core.NameScope(name_scope):
  1749. grad_squared_sums = []
  1750. for i, param in enumerate(params):
  1751. device = get_param_device(str(param.blob), param.grad, param_to_device)
  1752. with core.DeviceScope(device):
  1753. grad = (
  1754. param.grad
  1755. if not isinstance(param.grad, core.GradientSlice)
  1756. else param.grad.values
  1757. )
  1758. grad_squared_sum_name = "grad_{}_squared_sum".format(i)
  1759. grad_squared_sum = model.net.SumSqrElements(grad, grad_squared_sum_name)
  1760. grad_squared_sum_cpu = model.net.EnsureCPUOutput(grad_squared_sum)
  1761. grad_squared_sums.append(grad_squared_sum_cpu)
  1762. with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
  1763. grad_squared_full_sum = model.net.Sum(
  1764. grad_squared_sums, "grad_squared_full_sum"
  1765. )
  1766. global_norm = model.net.Pow(
  1767. grad_squared_full_sum, "global_norm", exponent=0.5
  1768. )
  1769. clip_norm = model.param_init_net.ConstantFill(
  1770. [], "clip_norm", shape=[], value=float(max_gradient_norm)
  1771. )
  1772. max_norm = model.net.Max([global_norm, clip_norm], "max_norm")
  1773. norm_ratio = model.net.Div([clip_norm, max_norm], "norm_ratio")
  1774. return norm_ratio
  1775. def _build(
  1776. model,
  1777. optimizer,
  1778. weights_only=False,
  1779. use_param_info_optim=True,
  1780. max_gradient_norm=None,
  1781. allow_lr_injection=False,
  1782. ):
  1783. param_to_device = _get_param_to_device(model)
  1784. # Validate there are no duplicate params
  1785. model.Validate()
  1786. params = []
  1787. for param_info in model.GetOptimizationParamInfo():
  1788. if weights_only and param_info.blob not in model.weights:
  1789. continue
  1790. params.append(param_info)
  1791. lr_multiplier = None
  1792. if max_gradient_norm is not None:
  1793. lr_multiplier = _calc_norm_ratio(
  1794. model,
  1795. params,
  1796. "norm_clipped_grad_update",
  1797. param_to_device,
  1798. max_gradient_norm,
  1799. )
  1800. if allow_lr_injection:
  1801. if not model.net.BlobIsDefined(_LEARNING_RATE_INJECTION):
  1802. lr_injection = model.param_init_net.ConstantFill(
  1803. [], _LEARNING_RATE_INJECTION, shape=[1], value=1.0
  1804. )
  1805. else:
  1806. lr_injection = _LEARNING_RATE_INJECTION
  1807. if lr_multiplier is None:
  1808. lr_multiplier = lr_injection
  1809. else:
  1810. lr_multiplier = model.net.Mul(
  1811. [lr_multiplier, lr_injection], "lr_multiplier", broadcast=1
  1812. )
  1813. optimizer.add_lr_multiplier(lr_multiplier)
  1814. for param_info in params:
  1815. param_name = str(param_info.blob)
  1816. device = get_param_device(param_name, param_info.grad, param_to_device)
  1817. with core.DeviceScope(device):
  1818. if param_info.optimizer and use_param_info_optim:
  1819. param_info.optimizer(model.net, model.param_init_net, param_info)
  1820. else:
  1821. optimizer(model.net, model.param_init_net, param_info)
  1822. return optimizer
  1823. def add_weight_decay(model, weight_decay):
  1824. """Adds a decay to weights in the model.
  1825. This is a form of L2 regularization.
  1826. Args:
  1827. weight_decay: strength of the regularization
  1828. """
  1829. _build(
  1830. model,
  1831. WeightDecayBuilder(weight_decay=weight_decay),
  1832. weights_only=True,
  1833. use_param_info_optim=False,
  1834. )
  1835. def build_sgd(
  1836. model,
  1837. base_learning_rate,
  1838. max_gradient_norm=None,
  1839. allow_lr_injection=False,
  1840. **kwargs
  1841. ):
  1842. sgd_optimizer = SgdOptimizer(base_learning_rate, **kwargs)
  1843. return _build(
  1844. model,
  1845. sgd_optimizer,
  1846. max_gradient_norm=max_gradient_norm,
  1847. allow_lr_injection=allow_lr_injection,
  1848. )
  1849. def build_multi_precision_sgd(
  1850. model,
  1851. base_learning_rate,
  1852. max_gradient_norm=None,
  1853. allow_lr_injection=False,
  1854. **kwargs
  1855. ):
  1856. multi_prec_sgd_optimizer = MultiPrecisionSgdOptimizer(base_learning_rate, **kwargs)
  1857. return _build(
  1858. model,
  1859. multi_prec_sgd_optimizer,
  1860. max_gradient_norm=max_gradient_norm,
  1861. allow_lr_injection=allow_lr_injection,
  1862. )
  1863. def build_fp16_sgd(model, base_learning_rate, **kwargs):
  1864. fp16_sgd_optimizer = FP16SgdOptimizer(base_learning_rate, **kwargs)
  1865. return _build(model, fp16_sgd_optimizer)
  1866. def build_ftrl(model, engine="SIMD", **kwargs):
  1867. if engine == "SIMD":
  1868. assert core.IsOperator("Ftrl_ENGINE_SIMD")
  1869. assert core.IsOperator("SparseFtrl_ENGINE_SIMD")
  1870. ftrl_optimizer = FtrlOptimizer(engine=engine, **kwargs)
  1871. return _build(model, ftrl_optimizer)
  1872. def build_gftrl(model, engine="", **kwargs):
  1873. if engine == "SIMD":
  1874. assert core.IsOperator("GFtrl_ENGINE_SIMD")
  1875. gftrl_optimizer = GFtrlOptimizer(engine=engine, **kwargs)
  1876. return _build(model, gftrl_optimizer)
  1877. def build_adagrad(
  1878. model,
  1879. base_learning_rate,
  1880. parameters=None,
  1881. max_gradient_norm=None,
  1882. allow_lr_injection=False,
  1883. **kwargs
  1884. ):
  1885. adagrad_optimizer = AdagradOptimizer(alpha=base_learning_rate, **kwargs)
  1886. return _build(
  1887. model,
  1888. adagrad_optimizer,
  1889. max_gradient_norm=max_gradient_norm,
  1890. allow_lr_injection=allow_lr_injection,
  1891. )
  1892. def build_wngrad(
  1893. model,
  1894. base_learning_rate,
  1895. parameters=None,
  1896. max_gradient_norm=None,
  1897. allow_lr_injection=False,
  1898. **kwargs
  1899. ):
  1900. wngrad_optimizer = WngradOptimizer(alpha=base_learning_rate, **kwargs)
  1901. return _build(
  1902. model,
  1903. wngrad_optimizer,
  1904. max_gradient_norm=max_gradient_norm,
  1905. allow_lr_injection=allow_lr_injection,
  1906. )
  1907. def build_storm(
  1908. model,
  1909. base_learning_rate,
  1910. parameters=None,
  1911. max_gradient_norm=None,
  1912. allow_lr_injection=False,
  1913. **kwargs
  1914. ):
  1915. storm_optimizer = StormOptimizer(lr=base_learning_rate, **kwargs)
  1916. return _build(
  1917. model,
  1918. storm_optimizer,
  1919. max_gradient_norm=max_gradient_norm,
  1920. allow_lr_injection=allow_lr_injection,
  1921. )
  1922. def build_adadelta(
  1923. model,
  1924. base_learning_rate,
  1925. parameters=None,
  1926. max_gradient_norm=None,
  1927. allow_lr_injection=False,
  1928. **kwargs
  1929. ):
  1930. adadelta_optimizer = AdadeltaOptimizer(alpha=base_learning_rate, **kwargs)
  1931. return _build(
  1932. model,
  1933. adadelta_optimizer,
  1934. max_gradient_norm=max_gradient_norm,
  1935. allow_lr_injection=allow_lr_injection,
  1936. )
  1937. def build_adam(
  1938. model,
  1939. base_learning_rate,
  1940. max_gradient_norm=None,
  1941. allow_lr_injection=False,
  1942. **kwargs
  1943. ):
  1944. adam_optimizer = AdamOptimizer(alpha=base_learning_rate, **kwargs)
  1945. return _build(
  1946. model,
  1947. adam_optimizer,
  1948. max_gradient_norm=max_gradient_norm,
  1949. allow_lr_injection=allow_lr_injection,
  1950. )
  1951. def build_decay_adagrad(
  1952. model,
  1953. base_learning_rate,
  1954. max_gradient_norm=None,
  1955. allow_lr_injection=False,
  1956. **kwargs
  1957. ):
  1958. decay_adagrad_optimizer = DecayAdagradOptimizer(alpha=base_learning_rate, **kwargs)
  1959. return _build(
  1960. model,
  1961. decay_adagrad_optimizer,
  1962. max_gradient_norm=max_gradient_norm,
  1963. allow_lr_injection=allow_lr_injection,
  1964. )
  1965. def build_yellowfin(model, base_learning_rate=0.1, **kwargs):
  1966. yellowfin_optimizer = YellowFinOptimizer(alpha=base_learning_rate, **kwargs)
  1967. return _build(model, yellowfin_optimizer)
  1968. def build_rms_prop(
  1969. model,
  1970. base_learning_rate,
  1971. max_gradient_norm=None,
  1972. allow_lr_injection=False,
  1973. **kwargs
  1974. ):
  1975. rms_prop_optimizer = RmsPropOptimizer(alpha=base_learning_rate, **kwargs)
  1976. return _build(
  1977. model,
  1978. rms_prop_optimizer,
  1979. max_gradient_norm=max_gradient_norm,
  1980. allow_lr_injection=allow_lr_injection,
  1981. )