train.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. ## @package train
  2. # Module caffe2.python.helpers.train
  3. from caffe2.python import core, scope
  4. from caffe2.proto import caffe2_pb2
  5. def _get_weights(model, namescope=None):
  6. if namescope is None:
  7. namescope = scope.CurrentNameScope()
  8. if namescope == '':
  9. return model.weights[:]
  10. else:
  11. return [w for w in model.weights if w.GetNameScope() == namescope]
  12. def iter(model, blob_out, **kwargs):
  13. if 'device_option' in kwargs:
  14. del kwargs['device_option']
  15. model.param_init_net.ConstantFill(
  16. [],
  17. blob_out,
  18. shape=[1],
  19. value=0,
  20. dtype=core.DataType.INT64,
  21. device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
  22. **kwargs
  23. )
  24. return model.net.Iter(blob_out, blob_out, **kwargs)
  25. def accuracy(model, blob_in, blob_out, **kwargs):
  26. dev = kwargs['device_option'] if 'device_option' in kwargs \
  27. else scope.CurrentDeviceScope()
  28. is_cpu = dev is None or dev.device_type == caffe2_pb2.CPU
  29. # We support top_k > 1 only on CPU
  30. if not is_cpu and 'top_k' in kwargs and kwargs['top_k'] > 1:
  31. pred_host = model.net.CopyGPUToCPU(blob_in[0], blob_in[0] + "_host")
  32. label_host = model.net.CopyGPUToCPU(blob_in[1], blob_in[1] + "_host")
  33. # Now use the Host version of the accuracy op
  34. model.net.Accuracy(
  35. [pred_host, label_host],
  36. blob_out,
  37. device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
  38. **kwargs
  39. )
  40. else:
  41. model.net.Accuracy(blob_in, blob_out)
  42. def add_weight_decay(model, weight_decay):
  43. """Adds a decay to weights in the model.
  44. This is a form of L2 regularization.
  45. Args:
  46. weight_decay: strength of the regularization
  47. """
  48. if weight_decay <= 0.0:
  49. return
  50. wd = model.param_init_net.ConstantFill(
  51. [], 'wd', shape=[1], value=weight_decay
  52. )
  53. ONE = model.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
  54. for param in _get_weights(model):
  55. # Equivalent to: grad += wd * param
  56. grad = model.param_to_grad[param]
  57. model.net.WeightedSum(
  58. [grad, ONE, param, wd],
  59. grad,
  60. )