| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- ## @package dot_product
- # Module caffe2.python.layers.dot_product
- from caffe2.python import schema
- from caffe2.python.layers.layers import (
- ModelLayer,
- )
class PairwiseSimilarity(ModelLayer):
    """Model layer producing pairwise similarity scores between embeddings.

    The input record must be a ``schema.Struct`` providing exactly one of:

    * ``all_embeddings`` -- used as both operands of the similarity, or
    * ``x_embeddings`` and ``y_embeddings`` -- the two operands.

    An optional ``indices_to_gather`` field selects entries of the flattened
    similarity matrix via ``BatchGather``. Every field must be a
    ``schema.Scalar``.

    The output is a ``schema.Scalar`` of shape ``(output_dim,)`` whose dtype
    is the base dtype of the x operand.
    """

    def __init__(self, model, input_record, output_dim,
                 pairwise_similarity_func='dot',
                 name='pairwise_similarity', **kwargs):
        """Validate the input record and declare the output schema.

        Args:
            model: forwarded to ``ModelLayer.__init__``.
            input_record: ``schema.Struct`` holding the embeddings (see the
                class docstring for the accepted fields).
            output_dim: size of the single dimension of the output Scalar.
            pairwise_similarity_func: ``'dot'`` or ``'cosine_similarity'``.
                The value is only checked when ``add_ops`` runs.
            name: layer name forwarded to ``ModelLayer.__init__``.
            **kwargs: forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``input_record`` is not a Struct, does not
                contain exactly one of the accepted embedding layouts, or
                contains a relevant field that is not a Scalar.
        """
        super(PairwiseSimilarity, self).__init__(
            model, name, input_record, **kwargs)
        assert isinstance(input_record, schema.Struct), (
            "Incorrect input type. Expected Struct, but received: {0}".
            format(input_record))

        # Evaluate the membership checks once; they drive both the
        # validation below and the choice of operands.
        has_all = 'all_embeddings' in input_record
        has_pair = (
            'x_embeddings' in input_record and
            'y_embeddings' in input_record
        )
        assert has_all ^ has_pair, (
            "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
            "should be given."
        )

        self.pairwise_similarity_func = pairwise_similarity_func

        if has_all:
            # A single embedding set is compared against itself.
            x_embeddings = input_record['all_embeddings']
            y_embeddings = input_record['all_embeddings']
        else:
            x_embeddings = input_record['x_embeddings']
            y_embeddings = input_record['y_embeddings']

        assert isinstance(x_embeddings, schema.Scalar), (
            "Incorrect input type for x. Expected Scalar, " +
            "but received: {0}".format(x_embeddings))
        assert isinstance(y_embeddings, schema.Scalar), (
            "Incorrect input type for y. Expected Scalar, " +
            "but received: {0}".format(y_embeddings)
        )

        if 'indices_to_gather' in input_record:
            indices_to_gather = input_record['indices_to_gather']
            assert isinstance(indices_to_gather, schema.Scalar), (
                "Incorrect type of indices_to_gather. "
                "Expected Scalar, but received: {0}".format(indices_to_gather)
            )
            self.indices_to_gather = indices_to_gather
        else:
            # None is the sentinel for "no gather step"; add_ops tests for
            # it explicitly.
            self.indices_to_gather = None

        self.x_embeddings = x_embeddings
        self.y_embeddings = y_embeddings

        # The output dtype follows the x operand.
        dtype = x_embeddings.field_types()[0].base

        self.output_schema = schema.Scalar(
            (dtype, (output_dim,)),
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        """Add the similarity operators for this layer to ``net``.

        For ``'cosine_similarity'`` both operands first go through the
        ``Normalize`` operator (``axis=1``); for ``'dot'`` they are used
        directly. The operands are then combined with ``BatchMatMul``
        (``trans_b=1``), the result is flattened, and, when
        ``indices_to_gather`` was supplied, ``BatchGather`` selects the
        requested entries into the output blob.

        Args:
            net: the net the operators are added to.

        Raises:
            NotImplementedError: if ``pairwise_similarity_func`` is neither
                ``'dot'`` nor ``'cosine_similarity'``.
        """
        if self.pairwise_similarity_func == "cosine_similarity":
            x_embeddings_norm = net.Normalize(self.x_embeddings(), axis=1)
            y_embeddings_norm = net.Normalize(self.y_embeddings(), axis=1)
            Y = net.BatchMatMul(
                [x_embeddings_norm, y_embeddings_norm],
                [self.get_next_blob_reference(x_embeddings_norm + '_matmul')],
                trans_b=1,
            )
        elif self.pairwise_similarity_func == "dot":
            Y = net.BatchMatMul(
                [self.x_embeddings(), self.y_embeddings()],
                [self.get_next_blob_reference(self.x_embeddings() + '_matmul')],
                trans_b=1,
            )
        else:
            raise NotImplementedError(
                "pairwise_similarity_func={} is not valid".format(
                    self.pairwise_similarity_func
                )
            )

        # Compare against the None sentinel set in __init__ rather than
        # relying on the truthiness of a schema object.
        if self.indices_to_gather is not None:
            flattened = net.Flatten(
                Y, Y + '_flatten',
            )
            net.BatchGather(
                [flattened, self.indices_to_gather()],
                self.output_schema(),
            )
        else:
            net.Flatten(Y, self.output_schema())
|