db_file_reader.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. ## @package db_file_reader
  2. # Module caffe2.python.db_file_reader
  3. from caffe2.python import core, scope, workspace, _import_c_extension as C
  4. from caffe2.python.dataio import Reader
  5. from caffe2.python.dataset import Dataset
  6. from caffe2.python.schema import from_column_list
  7. import os
  8. class DBFileReader(Reader):
  9. default_name_suffix = 'db_file_reader'
  10. """Reader reads from a DB file.
  11. Example usage:
  12. db_file_reader = DBFileReader(db_path='/tmp/cache.db', db_type='LevelDB')
  13. Args:
  14. db_path: str.
  15. db_type: str. DB type of file. A db_type is registed by
  16. `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
  17. name: str or None. Name of DBFileReader.
  18. Optional name to prepend to blobs that will store the data.
  19. Default to '<db_name>_<default_name_suffix>'.
  20. batch_size: int.
  21. How many examples are read for each time the read_net is run.
  22. loop_over: bool.
  23. If True given, will go through examples in random order endlessly.
  24. field_names: List[str]. If the schema.field_names() should not in
  25. alphabetic order, it must be specified.
  26. Otherwise, schema will be automatically restored with
  27. schema.field_names() sorted in alphabetic order.
  28. """
  29. def __init__(
  30. self,
  31. db_path,
  32. db_type,
  33. name=None,
  34. batch_size=100,
  35. loop_over=False,
  36. field_names=None,
  37. ):
  38. assert db_path is not None, "db_path can't be None."
  39. assert db_type in C.registered_dbs(), \
  40. "db_type [{db_type}] is not available. \n" \
  41. "Choose one of these: {registered_dbs}.".format(
  42. db_type=db_type,
  43. registered_dbs=C.registered_dbs(),
  44. )
  45. self.db_path = os.path.expanduser(db_path)
  46. self.db_type = db_type
  47. self.name = name or '{db_name}_{default_name_suffix}'.format(
  48. db_name=self._extract_db_name_from_db_path(),
  49. default_name_suffix=self.default_name_suffix,
  50. )
  51. self.batch_size = batch_size
  52. self.loop_over = loop_over
  53. # Before self._init_reader_schema(...),
  54. # self.db_path and self.db_type are required to be set.
  55. super(DBFileReader, self).__init__(self._init_reader_schema(field_names))
  56. self.ds = Dataset(self._schema, self.name + '_dataset')
  57. self.ds_reader = None
  58. def _init_name(self, name):
  59. return name or self._extract_db_name_from_db_path(
  60. ) + '_db_file_reader'
  61. def _init_reader_schema(self, field_names=None):
  62. """Restore a reader schema from the DB file.
  63. If `field_names` given, restore scheme according to it.
  64. Overwise, loade blobs from the DB file into the workspace,
  65. and restore schema from these blob names.
  66. It is also assumed that:
  67. 1). Each field of the schema have corresponding blobs
  68. stored in the DB file.
  69. 2). Each blob loaded from the DB file corresponds to
  70. a field of the schema.
  71. 3). field_names in the original schema are in alphabetic order,
  72. since blob names loaded to the workspace from the DB file
  73. will be in alphabetic order.
  74. Load a set of blobs from a DB file. From names of these blobs,
  75. restore the DB file schema using `from_column_list(...)`.
  76. Returns:
  77. schema: schema.Struct. Used in Reader.__init__(...).
  78. """
  79. if field_names:
  80. return from_column_list(field_names)
  81. if self.db_type == "log_file_db":
  82. assert os.path.exists(self.db_path), \
  83. 'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
  84. with core.NameScope(self.name):
  85. # blob_prefix is for avoiding name conflict in workspace
  86. blob_prefix = scope.CurrentNameScope()
  87. workspace.RunOperatorOnce(
  88. core.CreateOperator(
  89. 'Load',
  90. [],
  91. [],
  92. absolute_path=True,
  93. db=self.db_path,
  94. db_type=self.db_type,
  95. load_all=True,
  96. add_prefix=blob_prefix,
  97. )
  98. )
  99. col_names = [
  100. blob_name[len(blob_prefix):] for blob_name in sorted(workspace.Blobs())
  101. if blob_name.startswith(blob_prefix)
  102. ]
  103. schema = from_column_list(col_names)
  104. return schema
  105. def setup_ex(self, init_net, finish_net):
  106. """From the Dataset, create a _DatasetReader and setup a init_net.
  107. Make sure the _init_field_blobs_as_empty(...) is only called once.
  108. Because the underlying NewRecord(...) creats blobs by calling
  109. NextScopedBlob(...), so that references to previously-initiated
  110. empty blobs will be lost, causing accessibility issue.
  111. """
  112. if self.ds_reader:
  113. self.ds_reader.setup_ex(init_net, finish_net)
  114. else:
  115. self._init_field_blobs_as_empty(init_net)
  116. self._feed_field_blobs_from_db_file(init_net)
  117. self.ds_reader = self.ds.random_reader(
  118. init_net,
  119. batch_size=self.batch_size,
  120. loop_over=self.loop_over,
  121. )
  122. self.ds_reader.sort_and_shuffle(init_net)
  123. self.ds_reader.computeoffset(init_net)
  124. def read(self, read_net):
  125. assert self.ds_reader, 'setup_ex must be called first'
  126. return self.ds_reader.read(read_net)
  127. def _init_field_blobs_as_empty(self, init_net):
  128. """Initialize dataset field blobs by creating an empty record"""
  129. with core.NameScope(self.name):
  130. self.ds.init_empty(init_net)
  131. def _feed_field_blobs_from_db_file(self, net):
  132. """Load from the DB file at db_path and feed dataset field blobs"""
  133. if self.db_type == "log_file_db":
  134. assert os.path.exists(self.db_path), \
  135. 'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
  136. net.Load(
  137. [],
  138. self.ds.get_blobs(),
  139. db=self.db_path,
  140. db_type=self.db_type,
  141. absolute_path=True,
  142. source_blob_names=self.ds.field_names(),
  143. )
  144. def _extract_db_name_from_db_path(self):
  145. """Extract DB name from DB path
  146. E.g. given self.db_path=`/tmp/sample.db`, or
  147. self.db_path = `dper_test_data/cached_reader/sample.db`
  148. it returns `sample`.
  149. Returns:
  150. db_name: str.
  151. """
  152. return os.path.basename(self.db_path).rsplit('.', 1)[0]