| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- ## @package text_file_reader
- # Module caffe2.python.text_file_reader
- from caffe2.python import core
- from caffe2.python.dataio import Reader
- from caffe2.python.schema import Scalar, Struct, data_type_for_dtype
class TextFileReader(Reader):
    """
    Wrapper around operators for reading from text files.

    The constructor adds a CreateTextFileReader op to an init net, and
    `read` adds a TextFileReaderRead op (followed by an IsEmpty op) to the
    net it is given.
    """

    def __init__(self, init_net, filename, schema, num_passes=1, batch_size=1):
        """
        Create op for building a TextFileReader instance in the workspace.

        Args:
            init_net   : Net that will be run only once at startup.
            filename   : Path to file to read from.
            schema     : schema.Struct representing the schema of the data.
                         Currently, only support Struct of strings and float32.
            num_passes : Number of passes over the data.
            batch_size : Number of rows to read at a time.

        Raises:
            AssertionError: if `schema` is not a schema.Struct, or if any of
                its children is not a schema.Scalar.
        """
        # NOTE: validation is done with `assert`, so it is skipped when
        # Python runs with -O. Kept as-is to preserve the exception type.
        assert isinstance(schema, Struct), 'Schema must be a schema.Struct'
        # Only the child field is inspected; the field name is not needed.
        for _, child in schema.get_children():
            assert isinstance(child, Scalar), (
                'Only scalar fields are supported in TextFileReader.')
        # Translate each field dtype into the data type value passed to the
        # CreateTextFileReader op.
        field_types = [
            data_type_for_dtype(dtype) for dtype in schema.field_types()]
        Reader.__init__(self, schema)
        self._reader = init_net.CreateTextFileReader(
            [],
            filename=filename,
            num_passes=num_passes,
            field_types=field_types)
        self._batch_size = batch_size

    def read(self, net):
        """
        Create op for reading a batch of rows.

        Args:
            net : Net to which the read ops are added.

        Returns:
            A tuple `(is_empty, blobs)`, where `blobs` is the list of blobs
            output by TextFileReaderRead (one requested per schema field
            name) and `is_empty` is the output of an IsEmpty op applied to
            the first of those blobs, written to a scoped blob named from
            'should_stop'.
        """
        blobs = net.TextFileReaderRead(
            [self._reader],
            len(self.schema().field_names()),
            batch_size=self._batch_size)
        # A lone BlobReference is wrapped into a list so callers always
        # receive a sequence. isinstance is used rather than an exact type
        # comparison so BlobReference subclasses are wrapped as well.
        if isinstance(blobs, core.BlobReference):
            blobs = [blobs]

        is_empty = net.IsEmpty(
            [blobs[0]],
            core.ScopedBlobReference(net.NextName('should_stop'))
        )

        return (is_empty, blobs)
|