label_smooth.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
# @package label_smooth
# Module caffe2.python.layers.label_smooth
  17. from caffe2.python import core, schema
  18. from caffe2.python.layers.layers import ModelLayer
  19. import numpy as np
  20. class LabelSmooth(ModelLayer):
  21. def __init__(
  22. self, model, label, smooth_matrix, name='label_smooth', **kwargs
  23. ):
  24. super(LabelSmooth, self).__init__(model, name, label, **kwargs)
  25. self.label = label
  26. # shape as a list
  27. smooth_matrix = np.array(smooth_matrix).astype(np.float32).flatten()
  28. self.set_dim(smooth_matrix)
  29. self.set_smooth_matrix(smooth_matrix)
  30. self.output_schema = schema.Scalar(
  31. (np.float32, (self.dim, )),
  32. self.get_next_blob_reference('smoothed_label')
  33. )
  34. def set_dim(self, smooth_matrix):
  35. num_elements = smooth_matrix.size
  36. self.binary_prob_label = (num_elements == 2)
  37. if self.binary_prob_label:
  38. self.dim = 1
  39. else:
  40. assert np.sqrt(num_elements)**2 == num_elements
  41. self.dim = int(np.sqrt(num_elements))
  42. def set_smooth_matrix(self, smooth_matrix):
  43. if not self.binary_prob_label:
  44. self.smooth_matrix = self.model.add_global_constant(
  45. '%s_label_smooth_matrix' % self.name,
  46. array=smooth_matrix.reshape((self.dim, self.dim)),
  47. dtype=np.dtype(np.float32),
  48. )
  49. self.len = self.model.add_global_constant(
  50. '%s_label_dim' % self.name,
  51. array=self.dim,
  52. dtype=np.dtype(np.int64),
  53. )
  54. else:
  55. self.smooth_matrix = smooth_matrix
  56. def add_ops_for_binary_prob_label(self, net):
  57. if self.label.field_type().base != np.float32:
  58. float32_label = net.NextScopedBlob('float32_label')
  59. net.Cast([self.label()], [float32_label], to=core.DataType.FLOAT)
  60. else:
  61. float32_label = self.label()
  62. net.StumpFunc(
  63. float32_label,
  64. self.output_schema(),
  65. threshold=0.5,
  66. low_value=self.smooth_matrix[0],
  67. high_value=self.smooth_matrix[1],
  68. )
  69. def add_ops_for_categorical_label(self, net):
  70. if self.label.field_type().base != np.int64:
  71. int64_label = net.NextScopedBlob('int64_label')
  72. net.Cast([self.label()], [int64_label], to=core.DataType.INT64)
  73. else:
  74. int64_label = self.label()
  75. one_hot_label = net.NextScopedBlob('one_hot_label')
  76. net.OneHot([int64_label, self.len], [one_hot_label])
  77. net.MatMul([one_hot_label, self.smooth_matrix], self.output_schema())
  78. def add_ops(self, net):
  79. if self.binary_prob_label:
  80. self.add_ops_for_binary_prob_label(net)
  81. else:
  82. self.add_ops_for_categorical_label(net)