pipeline_test.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from caffe2.python.schema import (
  2. Struct, FetchRecord, NewRecord, FeedRecord, InitEmptyRecord)
  3. from caffe2.python import core, workspace
  4. from caffe2.python.session import LocalSession
  5. from caffe2.python.dataset import Dataset
  6. from caffe2.python.pipeline import pipe
  7. from caffe2.python.queue_util import Queue
  8. from caffe2.python.task import TaskGroup
  9. from caffe2.python.test_util import TestCase
  10. from caffe2.python.net_builder import ops
  11. import numpy as np
  12. import math
  13. class TestPipeline(TestCase):
  14. def test_dequeue_many(self):
  15. init_net = core.Net('init')
  16. N = 17
  17. NUM_DEQUEUE_RECORDS = 3
  18. src_values = Struct(
  19. ('uid', np.array(range(N))),
  20. ('value', 0.1 * np.array(range(N))))
  21. expected_dst = Struct(
  22. ('uid', 2 * np.array(range(N))),
  23. ('value', np.array(N * [0.0])))
  24. with core.NameScope('init'):
  25. src_blobs = NewRecord(init_net, src_values)
  26. dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())
  27. counter = init_net.Const(0)
  28. ONE = init_net.Const(1)
  29. def proc1(rec):
  30. with core.NameScope('proc1'):
  31. out = NewRecord(ops, rec)
  32. ops.Add([rec.uid(), rec.uid()], [out.uid()])
  33. out.value.set(blob=rec.value(), unsafe=True)
  34. return out
  35. def proc2(rec):
  36. with core.NameScope('proc2'):
  37. out = NewRecord(ops, rec)
  38. out.uid.set(blob=rec.uid(), unsafe=True)
  39. ops.Sub([rec.value(), rec.value()], [out.value()])
  40. ops.Add([counter, ONE], [counter])
  41. return out
  42. src_ds = Dataset(src_blobs)
  43. dst_ds = Dataset(dst_blobs)
  44. with TaskGroup() as tg:
  45. out1 = pipe(
  46. src_ds.reader(),
  47. output=Queue(
  48. capacity=11, num_dequeue_records=NUM_DEQUEUE_RECORDS),
  49. processor=proc1)
  50. out2 = pipe(out1, processor=proc2)
  51. pipe(out2, dst_ds.writer())
  52. ws = workspace.C.Workspace()
  53. FeedRecord(src_blobs, src_values, ws)
  54. session = LocalSession(ws)
  55. session.run(init_net)
  56. session.run(tg)
  57. output = FetchRecord(dst_blobs, ws=ws)
  58. num_dequeues = ws.blobs[str(counter)].fetch()
  59. self.assertEquals(
  60. num_dequeues, int(math.ceil(float(N) / NUM_DEQUEUE_RECORDS)))
  61. for a, b in zip(output.field_blobs(), expected_dst.field_blobs()):
  62. np.testing.assert_array_equal(a, b)