session_test.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from caffe2.python.schema import (
  2. Struct, FetchRecord, NewRecord, FeedRecord, InitEmptyRecord)
  3. from caffe2.python import core, workspace
  4. from caffe2.python.session import LocalSession
  5. from caffe2.python.dataset import Dataset
  6. from caffe2.python.pipeline import pipe
  7. from caffe2.python.task import TaskGroup
  8. from caffe2.python.test_util import TestCase
  9. import numpy as np
  10. class TestLocalSession(TestCase):
  11. def test_local_session(self):
  12. init_net = core.Net('init')
  13. src_values = Struct(
  14. ('uid', np.array([1, 2, 6])),
  15. ('value', np.array([1.4, 1.6, 1.7])))
  16. expected_dst = Struct(
  17. ('uid', np.array([2, 4, 12])),
  18. ('value', np.array([0.0, 0.0, 0.0])))
  19. with core.NameScope('init'):
  20. src_blobs = NewRecord(init_net, src_values)
  21. dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())
  22. def proc1(rec):
  23. net = core.Net('proc1')
  24. with core.NameScope('proc1'):
  25. out = NewRecord(net, rec)
  26. net.Add([rec.uid(), rec.uid()], [out.uid()])
  27. out.value.set(blob=rec.value(), unsafe=True)
  28. return [net], out
  29. def proc2(rec):
  30. net = core.Net('proc2')
  31. with core.NameScope('proc2'):
  32. out = NewRecord(net, rec)
  33. out.uid.set(blob=rec.uid(), unsafe=True)
  34. net.Sub([rec.value(), rec.value()], [out.value()])
  35. return [net], out
  36. src_ds = Dataset(src_blobs)
  37. dst_ds = Dataset(dst_blobs)
  38. with TaskGroup() as tg:
  39. out1 = pipe(src_ds.reader(), processor=proc1)
  40. out2 = pipe(out1, processor=proc2)
  41. pipe(out2, dst_ds.writer())
  42. ws = workspace.C.Workspace()
  43. FeedRecord(src_blobs, src_values, ws)
  44. session = LocalSession(ws)
  45. session.run(init_net)
  46. session.run(tg)
  47. output = FetchRecord(dst_blobs, ws=ws)
  48. for a, b in zip(output.field_blobs(), expected_dst.field_blobs()):
  49. np.testing.assert_array_equal(a, b)