text_file_reader.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. ## @package text_file_reader
  2. # Module caffe2.python.text_file_reader
  3. from caffe2.python import core
  4. from caffe2.python.dataio import Reader
  5. from caffe2.python.schema import Scalar, Struct, data_type_for_dtype
  6. class TextFileReader(Reader):
  7. """
  8. Wrapper around operators for reading from text files.
  9. """
  10. def __init__(self, init_net, filename, schema, num_passes=1, batch_size=1):
  11. """
  12. Create op for building a TextFileReader instance in the workspace.
  13. Args:
  14. init_net : Net that will be run only once at startup.
  15. filename : Path to file to read from.
  16. schema : schema.Struct representing the schema of the data.
  17. Currently, only support Struct of strings and float32.
  18. num_passes : Number of passes over the data.
  19. batch_size : Number of rows to read at a time.
  20. """
  21. assert isinstance(schema, Struct), 'Schema must be a schema.Struct'
  22. for name, child in schema.get_children():
  23. assert isinstance(child, Scalar), (
  24. 'Only scalar fields are supported in TextFileReader.')
  25. field_types = [
  26. data_type_for_dtype(dtype) for dtype in schema.field_types()]
  27. Reader.__init__(self, schema)
  28. self._reader = init_net.CreateTextFileReader(
  29. [],
  30. filename=filename,
  31. num_passes=num_passes,
  32. field_types=field_types)
  33. self._batch_size = batch_size
  34. def read(self, net):
  35. """
  36. Create op for reading a batch of rows.
  37. """
  38. blobs = net.TextFileReaderRead(
  39. [self._reader],
  40. len(self.schema().field_names()),
  41. batch_size=self._batch_size)
  42. if type(blobs) is core.BlobReference:
  43. blobs = [blobs]
  44. is_empty = net.IsEmpty(
  45. [blobs[0]],
  46. core.ScopedBlobReference(net.NextName('should_stop'))
  47. )
  48. return (is_empty, blobs)