hsm_util.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. ## @package hsm_util
  2. # Module caffe2.python.hsm_util
  3. from caffe2.proto import hsm_pb2
'''
Hierarchical softmax utility methods that can be used to:
1) create a TreeProto structure from a list of word_ids or NodeProtos
2) create a HierarchyProto structure from a user-supplied TreeProto
'''
  9. def create_node_with_words(words, name='node'):
  10. node = hsm_pb2.NodeProto()
  11. node.name = name
  12. for word in words:
  13. node.word_ids.append(word)
  14. return node
  15. def create_node_with_nodes(nodes, name='node'):
  16. node = hsm_pb2.NodeProto()
  17. node.name = name
  18. for child_node in nodes:
  19. new_child_node = node.children.add()
  20. new_child_node.MergeFrom(child_node)
  21. return node
  22. def create_hierarchy(tree_proto):
  23. max_index = 0
  24. def create_path(path, word):
  25. path_proto = hsm_pb2.PathProto()
  26. path_proto.word_id = word
  27. for entry in path:
  28. new_path_node = path_proto.path_nodes.add()
  29. new_path_node.index = entry[0]
  30. new_path_node.length = entry[1]
  31. new_path_node.target = entry[2]
  32. return path_proto
  33. def recursive_path_builder(node_proto, path, hierarchy_proto, max_index):
  34. node_proto.offset = max_index
  35. path.append([max_index,
  36. len(node_proto.word_ids) + len(node_proto.children), 0])
  37. max_index += len(node_proto.word_ids) + len(node_proto.children)
  38. if hierarchy_proto.size < max_index:
  39. hierarchy_proto.size = max_index
  40. for target, node in enumerate(node_proto.children):
  41. path[-1][2] = target
  42. max_index = recursive_path_builder(node, path, hierarchy_proto,
  43. max_index)
  44. for target, word in enumerate(node_proto.word_ids):
  45. path[-1][2] = target + len(node_proto.children)
  46. path_entry = create_path(path, word)
  47. new_path_entry = hierarchy_proto.paths.add()
  48. new_path_entry.MergeFrom(path_entry)
  49. del path[-1]
  50. return max_index
  51. node = tree_proto.root_node
  52. hierarchy_proto = hsm_pb2.HierarchyProto()
  53. path = []
  54. max_index = recursive_path_builder(node, path, hierarchy_proto, max_index)
  55. return hierarchy_proto