# schema_test.py: unit tests for caffe2.python.schema
  1. from caffe2.python import core, schema
  2. import numpy as np
  3. import unittest
  4. import pickle
  5. import random
  6. class TestField(unittest.TestCase):
  7. def testInitShouldSetEmptyParent(self):
  8. f = schema.Field([])
  9. self.assertTupleEqual(f._parent, (None, 0))
  10. def testInitShouldSetFieldOffsets(self):
  11. f = schema.Field([
  12. schema.Scalar(dtype=np.int32),
  13. schema.Struct(
  14. ('field1', schema.Scalar(dtype=np.int32)),
  15. ('field2', schema.List(schema.Scalar(dtype=str))),
  16. ),
  17. schema.Scalar(dtype=np.int32),
  18. schema.Struct(
  19. ('field3', schema.Scalar(dtype=np.int32)),
  20. ('field4', schema.List(schema.Scalar(dtype=str)))
  21. ),
  22. schema.Scalar(dtype=np.int32),
  23. ])
  24. self.assertListEqual(f._field_offsets, [0, 1, 4, 5, 8, 9])
  25. def testInitShouldSetFieldOffsetsIfNoChildren(self):
  26. f = schema.Field([])
  27. self.assertListEqual(f._field_offsets, [0])
  28. class TestDB(unittest.TestCase):
  29. def testPicklable(self):
  30. s = schema.Struct(
  31. ('field1', schema.Scalar(dtype=np.int32)),
  32. ('field2', schema.List(schema.Scalar(dtype=str)))
  33. )
  34. s2 = pickle.loads(pickle.dumps(s))
  35. for r in (s, s2):
  36. self.assertTrue(isinstance(r.field1, schema.Scalar))
  37. self.assertTrue(isinstance(r.field2, schema.List))
  38. self.assertTrue(getattr(r, 'non_existent', None) is None)
  39. def testListSubclassClone(self):
  40. class Subclass(schema.List):
  41. pass
  42. s = Subclass(schema.Scalar())
  43. clone = s.clone()
  44. self.assertIsInstance(clone, Subclass)
  45. self.assertEqual(s, clone)
  46. self.assertIsNot(clone, s)
  47. def testListWithEvictedSubclassClone(self):
  48. class Subclass(schema.ListWithEvicted):
  49. pass
  50. s = Subclass(schema.Scalar())
  51. clone = s.clone()
  52. self.assertIsInstance(clone, Subclass)
  53. self.assertEqual(s, clone)
  54. self.assertIsNot(clone, s)
  55. def testStructSubclassClone(self):
  56. class Subclass(schema.Struct):
  57. pass
  58. s = Subclass(
  59. ('a', schema.Scalar()),
  60. )
  61. clone = s.clone()
  62. self.assertIsInstance(clone, Subclass)
  63. self.assertEqual(s, clone)
  64. self.assertIsNot(clone, s)
  65. def testNormalizeField(self):
  66. s = schema.Struct(('field1', np.int32), ('field2', str))
  67. self.assertEquals(
  68. s,
  69. schema.Struct(
  70. ('field1', schema.Scalar(dtype=np.int32)),
  71. ('field2', schema.Scalar(dtype=str))
  72. )
  73. )
  74. def testTuple(self):
  75. s = schema.Tuple(np.int32, str, np.float32)
  76. s2 = schema.Struct(
  77. ('field_0', schema.Scalar(dtype=np.int32)),
  78. ('field_1', schema.Scalar(dtype=np.str)),
  79. ('field_2', schema.Scalar(dtype=np.float32))
  80. )
  81. self.assertEquals(s, s2)
  82. self.assertEquals(s[0], schema.Scalar(dtype=np.int32))
  83. self.assertEquals(s[1], schema.Scalar(dtype=np.str))
  84. self.assertEquals(s[2], schema.Scalar(dtype=np.float32))
  85. self.assertEquals(
  86. s[2, 0],
  87. schema.Struct(
  88. ('field_2', schema.Scalar(dtype=np.float32)),
  89. ('field_0', schema.Scalar(dtype=np.int32)),
  90. )
  91. )
  92. # test iterator behavior
  93. for i, (v1, v2) in enumerate(zip(s, s2)):
  94. self.assertEquals(v1, v2)
  95. self.assertEquals(s[i], v1)
  96. self.assertEquals(s2[i], v1)
  97. def testRawTuple(self):
  98. s = schema.RawTuple(2)
  99. self.assertEquals(
  100. s, schema.Struct(
  101. ('field_0', schema.Scalar()), ('field_1', schema.Scalar())
  102. )
  103. )
  104. self.assertEquals(s[0], schema.Scalar())
  105. self.assertEquals(s[1], schema.Scalar())
  106. def testStructIndexing(self):
  107. s = schema.Struct(
  108. ('field1', schema.Scalar(dtype=np.int32)),
  109. ('field2', schema.List(schema.Scalar(dtype=str))),
  110. ('field3', schema.Struct()),
  111. )
  112. self.assertEquals(s['field2'], s.field2)
  113. self.assertEquals(s['field2'], schema.List(schema.Scalar(dtype=str)))
  114. self.assertEquals(s['field3'], schema.Struct())
  115. self.assertEquals(
  116. s['field2', 'field1'],
  117. schema.Struct(
  118. ('field2', schema.List(schema.Scalar(dtype=str))),
  119. ('field1', schema.Scalar(dtype=np.int32)),
  120. )
  121. )
  122. def testListInStructIndexing(self):
  123. a = schema.List(schema.Scalar(dtype=str))
  124. s = schema.Struct(
  125. ('field1', schema.Scalar(dtype=np.int32)),
  126. ('field2', a)
  127. )
  128. self.assertEquals(s['field2:lengths'], a.lengths)
  129. self.assertEquals(s['field2:values'], a.items)
  130. with self.assertRaises(KeyError):
  131. s['fields2:items:non_existent']
  132. with self.assertRaises(KeyError):
  133. s['fields2:non_existent']
  134. def testListWithEvictedInStructIndexing(self):
  135. a = schema.ListWithEvicted(schema.Scalar(dtype=str))
  136. s = schema.Struct(
  137. ('field1', schema.Scalar(dtype=np.int32)),
  138. ('field2', a)
  139. )
  140. self.assertEquals(s['field2:lengths'], a.lengths)
  141. self.assertEquals(s['field2:values'], a.items)
  142. self.assertEquals(s['field2:_evicted_values'], a._evicted_values)
  143. with self.assertRaises(KeyError):
  144. s['fields2:items:non_existent']
  145. with self.assertRaises(KeyError):
  146. s['fields2:non_existent']
  147. def testMapInStructIndexing(self):
  148. a = schema.Map(
  149. schema.Scalar(dtype=np.int32),
  150. schema.Scalar(dtype=np.float32),
  151. )
  152. s = schema.Struct(
  153. ('field1', schema.Scalar(dtype=np.int32)),
  154. ('field2', a)
  155. )
  156. self.assertEquals(s['field2:values:keys'], a.keys)
  157. self.assertEquals(s['field2:values:values'], a.values)
  158. with self.assertRaises(KeyError):
  159. s['fields2:keys:non_existent']
  160. def testPreservesMetadata(self):
  161. s = schema.Struct(
  162. ('a', schema.Scalar(np.float32)), (
  163. 'b', schema.Scalar(
  164. np.int32,
  165. metadata=schema.Metadata(categorical_limit=5)
  166. )
  167. ), (
  168. 'c', schema.List(
  169. schema.Scalar(
  170. np.int32,
  171. metadata=schema.Metadata(categorical_limit=6)
  172. )
  173. )
  174. )
  175. )
  176. # attach metadata to lengths field
  177. s.c.lengths.set_metadata(schema.Metadata(categorical_limit=7))
  178. self.assertEqual(None, s.a.metadata)
  179. self.assertEqual(5, s.b.metadata.categorical_limit)
  180. self.assertEqual(6, s.c.value.metadata.categorical_limit)
  181. self.assertEqual(7, s.c.lengths.metadata.categorical_limit)
  182. sc = s.clone()
  183. self.assertEqual(None, sc.a.metadata)
  184. self.assertEqual(5, sc.b.metadata.categorical_limit)
  185. self.assertEqual(6, sc.c.value.metadata.categorical_limit)
  186. self.assertEqual(7, sc.c.lengths.metadata.categorical_limit)
  187. sv = schema.from_blob_list(
  188. s, [
  189. np.array([3.4]), np.array([2]), np.array([3]),
  190. np.array([1, 2, 3])
  191. ]
  192. )
  193. self.assertEqual(None, sv.a.metadata)
  194. self.assertEqual(5, sv.b.metadata.categorical_limit)
  195. self.assertEqual(6, sv.c.value.metadata.categorical_limit)
  196. self.assertEqual(7, sv.c.lengths.metadata.categorical_limit)
  197. def testDupField(self):
  198. with self.assertRaises(ValueError):
  199. schema.Struct(
  200. ('a', schema.Scalar()),
  201. ('a', schema.Scalar()))
  202. def testAssignToField(self):
  203. with self.assertRaises(TypeError):
  204. s = schema.Struct(('a', schema.Scalar()))
  205. s.a = schema.Scalar()
  206. def testPreservesEmptyFields(self):
  207. s = schema.Struct(
  208. ('a', schema.Scalar(np.float32)),
  209. ('b', schema.Struct()),
  210. )
  211. sc = s.clone()
  212. self.assertIn("a", sc.fields)
  213. self.assertIn("b", sc.fields)
  214. sv = schema.from_blob_list(s, [np.array([3.4])])
  215. self.assertIn("a", sv.fields)
  216. self.assertIn("b", sv.fields)
  217. self.assertEqual(0, len(sv.b.fields))
  218. def testStructSubstraction(self):
  219. s1 = schema.Struct(
  220. ('a', schema.Scalar()),
  221. ('b', schema.Scalar()),
  222. ('c', schema.Scalar()),
  223. )
  224. s2 = schema.Struct(
  225. ('b', schema.Scalar())
  226. )
  227. s = s1 - s2
  228. self.assertEqual(['a', 'c'], s.field_names())
  229. s3 = schema.Struct(
  230. ('a', schema.Scalar())
  231. )
  232. s = s1 - s3
  233. self.assertEqual(['b', 'c'], s.field_names())
  234. with self.assertRaises(TypeError):
  235. s1 - schema.Scalar()
  236. def testStructNestedSubstraction(self):
  237. s1 = schema.Struct(
  238. ('a', schema.Scalar()),
  239. ('b', schema.Struct(
  240. ('c', schema.Scalar()),
  241. ('d', schema.Scalar()),
  242. ('e', schema.Scalar()),
  243. ('f', schema.Scalar()),
  244. )),
  245. )
  246. s2 = schema.Struct(
  247. ('b', schema.Struct(
  248. ('d', schema.Scalar()),
  249. ('e', schema.Scalar()),
  250. )),
  251. )
  252. s = s1 - s2
  253. self.assertEqual(['a', 'b:c', 'b:f'], s.field_names())
  254. def testStructAddition(self):
  255. s1 = schema.Struct(
  256. ('a', schema.Scalar())
  257. )
  258. s2 = schema.Struct(
  259. ('b', schema.Scalar())
  260. )
  261. s = s1 + s2
  262. self.assertIn("a", s.fields)
  263. self.assertIn("b", s.fields)
  264. with self.assertRaises(TypeError):
  265. s1 + s1
  266. with self.assertRaises(TypeError):
  267. s1 + schema.Scalar()
  268. def testStructNestedAddition(self):
  269. s1 = schema.Struct(
  270. ('a', schema.Scalar()),
  271. ('b', schema.Struct(
  272. ('c', schema.Scalar())
  273. )),
  274. )
  275. s2 = schema.Struct(
  276. ('b', schema.Struct(
  277. ('d', schema.Scalar())
  278. ))
  279. )
  280. s = s1 + s2
  281. self.assertEqual(['a', 'b:c', 'b:d'], s.field_names())
  282. s3 = schema.Struct(
  283. ('b', schema.Scalar()),
  284. )
  285. with self.assertRaises(TypeError):
  286. s = s1 + s3
  287. def testGetFieldByNestedName(self):
  288. st = schema.Struct(
  289. ('a', schema.Scalar()),
  290. ('b', schema.Struct(
  291. ('c', schema.Struct(
  292. ('d', schema.Scalar()),
  293. )),
  294. )),
  295. )
  296. self.assertRaises(KeyError, st.__getitem__, '')
  297. self.assertRaises(KeyError, st.__getitem__, 'x')
  298. self.assertRaises(KeyError, st.__getitem__, 'x:y')
  299. self.assertRaises(KeyError, st.__getitem__, 'b:c:x')
  300. a = st['a']
  301. self.assertTrue(isinstance(a, schema.Scalar))
  302. bc = st['b:c']
  303. self.assertIn('d', bc.fields)
  304. bcd = st['b:c:d']
  305. self.assertTrue(isinstance(bcd, schema.Scalar))
  306. def testAddFieldByNestedName(self):
  307. f_a = schema.Scalar(blob=core.BlobReference('blob1'))
  308. f_b = schema.Struct(
  309. ('c', schema.Struct(
  310. ('d', schema.Scalar(blob=core.BlobReference('blob2'))),
  311. )),
  312. )
  313. f_x = schema.Struct(
  314. ('x', schema.Scalar(blob=core.BlobReference('blob3'))),
  315. )
  316. with self.assertRaises(TypeError):
  317. st = schema.Struct(
  318. ('a', f_a),
  319. ('b', f_b),
  320. ('b:c:d', f_x),
  321. )
  322. with self.assertRaises(TypeError):
  323. st = schema.Struct(
  324. ('a', f_a),
  325. ('b', f_b),
  326. ('b:c:d:e', f_x),
  327. )
  328. st = schema.Struct(
  329. ('a', f_a),
  330. ('b', f_b),
  331. ('e:f', f_x),
  332. )
  333. self.assertEqual(['a', 'b:c:d', 'e:f:x'], st.field_names())
  334. self.assertEqual(['blob1', 'blob2', 'blob3'], st.field_blobs())
  335. st = schema.Struct(
  336. ('a', f_a),
  337. ('b:c:e', f_x),
  338. ('b', f_b),
  339. )
  340. self.assertEqual(['a', 'b:c:e:x', 'b:c:d'], st.field_names())
  341. self.assertEqual(['blob1', 'blob3', 'blob2'], st.field_blobs())
  342. st = schema.Struct(
  343. ('a:a1', f_a),
  344. ('b:b1', f_b),
  345. ('a', f_x),
  346. )
  347. self.assertEqual(['a:a1', 'a:x', 'b:b1:c:d'], st.field_names())
  348. self.assertEqual(['blob1', 'blob3', 'blob2'], st.field_blobs())
  349. def testContains(self):
  350. st = schema.Struct(
  351. ('a', schema.Scalar()),
  352. ('b', schema.Struct(
  353. ('c', schema.Struct(
  354. ('d', schema.Scalar()),
  355. )),
  356. )),
  357. )
  358. self.assertTrue('a' in st)
  359. self.assertTrue('b:c' in st)
  360. self.assertTrue('b:c:d' in st)
  361. self.assertFalse('' in st)
  362. self.assertFalse('x' in st)
  363. self.assertFalse('b:c:x' in st)
  364. self.assertFalse('b:c:d:x' in st)
  365. def testFromEmptyColumnList(self):
  366. st = schema.Struct()
  367. columns = st.field_names()
  368. rec = schema.from_column_list(col_names=columns)
  369. self.assertEqual(rec, schema.Struct())
  370. def testFromColumnList(self):
  371. st = schema.Struct(
  372. ('a', schema.Scalar()),
  373. ('b', schema.List(schema.Scalar())),
  374. ('c', schema.Map(schema.Scalar(), schema.Scalar()))
  375. )
  376. columns = st.field_names()
  377. # test that recovery works for arbitrary order
  378. for _ in range(10):
  379. some_blobs = [core.BlobReference('blob:' + x) for x in columns]
  380. rec = schema.from_column_list(columns, col_blobs=some_blobs)
  381. self.assertTrue(rec.has_blobs())
  382. self.assertEqual(sorted(st.field_names()), sorted(rec.field_names()))
  383. self.assertEqual([str(blob) for blob in rec.field_blobs()],
  384. [str('blob:' + name) for name in rec.field_names()])
  385. random.shuffle(columns)
  386. def testStructGet(self):
  387. net = core.Net('test_net')
  388. s1 = schema.NewRecord(net, schema.Scalar(np.float32))
  389. s2 = schema.NewRecord(net, schema.Scalar(np.float32))
  390. t = schema.Tuple(s1, s2)
  391. assert t.get('field_0', None) == s1
  392. assert t.get('field_1', None) == s2
  393. assert t.get('field_2', None) is None
  394. def testScalarForVoidType(self):
  395. s0_good = schema.Scalar((None, (2, )))
  396. with self.assertRaises(TypeError):
  397. s0_bad = schema.Scalar((np.void, (2, )))
  398. s1_good = schema.Scalar(np.void)
  399. s2_good = schema.Scalar(None)
  400. assert s1_good == s2_good
  401. def testScalarShape(self):
  402. s0 = schema.Scalar(np.int32)
  403. self.assertEqual(s0.field_type().shape, ())
  404. s1_good = schema.Scalar((np.int32, 5))
  405. self.assertEqual(s1_good.field_type().shape, (5, ))
  406. with self.assertRaises(ValueError):
  407. s1_bad = schema.Scalar((np.int32, -1))
  408. s1_hard = schema.Scalar((np.int32, 1))
  409. self.assertEqual(s1_hard.field_type().shape, (1, ))
  410. s2 = schema.Scalar((np.int32, (2, 3)))
  411. self.assertEqual(s2.field_type().shape, (2, 3))
  412. def testDtypeForCoreType(self):
  413. dtype = schema.dtype_for_core_type(core.DataType.FLOAT16)
  414. self.assertEqual(dtype, np.float16)
  415. with self.assertRaises(TypeError):
  416. schema.dtype_for_core_type(100)