pairwise_similarity.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
## @package pairwise_similarity
# Module caffe2.python.layers.pairwise_similarity
  3. from caffe2.python import schema
  4. from caffe2.python.layers.layers import (
  5. ModelLayer,
  6. )
  7. class PairwiseSimilarity(ModelLayer):
  8. def __init__(self, model, input_record, output_dim, pairwise_similarity_func='dot',
  9. name='pairwise_similarity', **kwargs):
  10. super(PairwiseSimilarity, self).__init__(model, name, input_record, **kwargs)
  11. assert isinstance(input_record, schema.Struct), (
  12. "Incorrect input type. Expected Struct, but received: {0}".
  13. format(input_record))
  14. assert (
  15. ('all_embeddings' in input_record) ^
  16. ('x_embeddings' in input_record and 'y_embeddings' in input_record)
  17. ), (
  18. "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
  19. "should be given."
  20. )
  21. self.pairwise_similarity_func = pairwise_similarity_func
  22. if 'all_embeddings' in input_record:
  23. x_embeddings = input_record['all_embeddings']
  24. y_embeddings = input_record['all_embeddings']
  25. else:
  26. x_embeddings = input_record['x_embeddings']
  27. y_embeddings = input_record['y_embeddings']
  28. assert isinstance(x_embeddings, schema.Scalar), (
  29. "Incorrect input type for x. Expected Scalar, " +
  30. "but received: {0}".format(x_embeddings))
  31. assert isinstance(y_embeddings, schema.Scalar), (
  32. "Incorrect input type for y. Expected Scalar, " +
  33. "but received: {0}".format(y_embeddings)
  34. )
  35. if 'indices_to_gather' in input_record:
  36. indices_to_gather = input_record['indices_to_gather']
  37. assert isinstance(indices_to_gather, schema.Scalar), (
  38. "Incorrect type of indices_to_gather. "
  39. "Expected Scalar, but received: {0}".format(indices_to_gather)
  40. )
  41. self.indices_to_gather = indices_to_gather
  42. else:
  43. self.indices_to_gather = None
  44. self.x_embeddings = x_embeddings
  45. self.y_embeddings = y_embeddings
  46. dtype = x_embeddings.field_types()[0].base
  47. self.output_schema = schema.Scalar(
  48. (dtype, (output_dim,)),
  49. self.get_next_blob_reference('output')
  50. )
  51. def add_ops(self, net):
  52. if self.pairwise_similarity_func == "cosine_similarity":
  53. x_embeddings_norm = net.Normalize(self.x_embeddings(), axis=1)
  54. y_embeddings_norm = net.Normalize(self.y_embeddings(), axis=1)
  55. Y = net.BatchMatMul(
  56. [x_embeddings_norm, y_embeddings_norm],
  57. [self.get_next_blob_reference(x_embeddings_norm + '_matmul')],
  58. trans_b=1,
  59. )
  60. elif self.pairwise_similarity_func == "dot":
  61. Y = net.BatchMatMul(
  62. [self.x_embeddings(), self.y_embeddings()],
  63. [self.get_next_blob_reference(self.x_embeddings() + '_matmul')],
  64. trans_b=1,
  65. )
  66. else:
  67. raise NotImplementedError(
  68. "pairwise_similarity_func={} is not valid".format(
  69. self.pairwise_similarity_func
  70. )
  71. )
  72. if self.indices_to_gather:
  73. flattened = net.Flatten(
  74. Y, Y + '_flatten',
  75. )
  76. net.BatchGather(
  77. [flattened, self.indices_to_gather()],
  78. self.output_schema(),
  79. )
  80. else:
  81. net.Flatten(Y, self.output_schema())