| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- ## @package hsm_util
- # Module caffe2.python.hsm_util
- from caffe2.proto import hsm_pb2
- '''
- Hierarchical softmax utility methods that can be used to:
- 1) create TreeProto structure given list of word_ids or NodeProtos
- 2) create HierarchyProto structure using the user-inputted TreeProto
- '''
def create_node_with_words(words, name='node'):
    """Create a NodeProto named ``name`` that holds the given word ids.

    Args:
        words: iterable of word ids; each one is appended, in order, to the
            node's ``word_ids`` field.
        name: value stored in the node's ``name`` field.

    Returns:
        The newly created ``hsm_pb2.NodeProto``.
    """
    leaf = hsm_pb2.NodeProto()
    leaf.name = name
    add_word_id = leaf.word_ids.append
    for word_id in words:
        add_word_id(word_id)
    return leaf
def create_node_with_nodes(nodes, name='node'):
    """Create a NodeProto named ``name`` whose children copy ``nodes``.

    Args:
        nodes: iterable of NodeProto messages. Each one is merged into a
            freshly added entry of the new node's ``children`` field, so the
            originals are copied rather than referenced.
        name: value stored in the node's ``name`` field.

    Returns:
        The newly created ``hsm_pb2.NodeProto``.
    """
    parent = hsm_pb2.NodeProto()
    parent.name = name
    for source_child in nodes:
        # add() appends an empty child message; MergeFrom fills it in.
        parent.children.add().MergeFrom(source_child)
    return parent
def create_hierarchy(tree_proto):
    """Build a HierarchyProto from the tree under ``tree_proto.root_node``.

    The tree is walked depth-first. Every visited node has its ``offset``
    field set to the next free index (so ``tree_proto`` is modified in place)
    and then claims one index slot per child node plus one per word id. The
    result's ``size`` is raised to the highest index handed out.

    For each word id a PathProto is appended to the result's ``paths``; it
    lists, for every node from the root down to the node holding the word,
    the node's offset (``index``), its slot count (``length``) and the slot
    taken at that node (``target``). Children occupy targets
    ``0 .. len(children) - 1``; word ids follow them.

    Args:
        tree_proto: message exposing a ``root_node`` NodeProto.

    Returns:
        The populated ``hsm_pb2.HierarchyProto``.
    """
    hierarchy = hsm_pb2.HierarchyProto()
    # One [offset, slot_count, target] entry per node between the root and
    # the node currently being visited.
    trail = []

    def snapshot_trail(word_id):
        # Copy the current root-to-node trail into a standalone PathProto.
        word_path = hsm_pb2.PathProto()
        word_path.word_id = word_id
        for index, length, target in trail:
            step = word_path.path_nodes.add()
            step.index = index
            step.length = length
            step.target = target
        return word_path

    def visit(node_proto, next_index):
        # Returns the first index still unused after this subtree.
        node_proto.offset = next_index
        child_count = len(node_proto.children)
        slot_count = len(node_proto.word_ids) + child_count
        frame = [next_index, slot_count, 0]
        trail.append(frame)
        next_index += slot_count
        # Only write size when it actually grows.
        if hierarchy.size < next_index:
            hierarchy.size = next_index
        for child_pos, child in enumerate(node_proto.children):
            frame[2] = child_pos
            next_index = visit(child, next_index)
        for word_pos, word_id in enumerate(node_proto.word_ids):
            # Word targets are numbered after all child targets.
            frame[2] = word_pos + child_count
            word_path = snapshot_trail(word_id)
            hierarchy.paths.add().MergeFrom(word_path)
        trail.pop()
        return next_index

    visit(tree_proto.root_node, 0)
    return hierarchy
|