| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471 |
- from caffe2.python import core, schema
- import numpy as np
- import unittest
- import pickle
- import random
- class TestField(unittest.TestCase):
- def testInitShouldSetEmptyParent(self):
- f = schema.Field([])
- self.assertTupleEqual(f._parent, (None, 0))
- def testInitShouldSetFieldOffsets(self):
- f = schema.Field([
- schema.Scalar(dtype=np.int32),
- schema.Struct(
- ('field1', schema.Scalar(dtype=np.int32)),
- ('field2', schema.List(schema.Scalar(dtype=str))),
- ),
- schema.Scalar(dtype=np.int32),
- schema.Struct(
- ('field3', schema.Scalar(dtype=np.int32)),
- ('field4', schema.List(schema.Scalar(dtype=str)))
- ),
- schema.Scalar(dtype=np.int32),
- ])
- self.assertListEqual(f._field_offsets, [0, 1, 4, 5, 8, 9])
- def testInitShouldSetFieldOffsetsIfNoChildren(self):
- f = schema.Field([])
- self.assertListEqual(f._field_offsets, [0])
- class TestDB(unittest.TestCase):
- def testPicklable(self):
- s = schema.Struct(
- ('field1', schema.Scalar(dtype=np.int32)),
- ('field2', schema.List(schema.Scalar(dtype=str)))
- )
- s2 = pickle.loads(pickle.dumps(s))
- for r in (s, s2):
- self.assertTrue(isinstance(r.field1, schema.Scalar))
- self.assertTrue(isinstance(r.field2, schema.List))
- self.assertTrue(getattr(r, 'non_existent', None) is None)
- def testListSubclassClone(self):
- class Subclass(schema.List):
- pass
- s = Subclass(schema.Scalar())
- clone = s.clone()
- self.assertIsInstance(clone, Subclass)
- self.assertEqual(s, clone)
- self.assertIsNot(clone, s)
- def testListWithEvictedSubclassClone(self):
- class Subclass(schema.ListWithEvicted):
- pass
- s = Subclass(schema.Scalar())
- clone = s.clone()
- self.assertIsInstance(clone, Subclass)
- self.assertEqual(s, clone)
- self.assertIsNot(clone, s)
- def testStructSubclassClone(self):
- class Subclass(schema.Struct):
- pass
- s = Subclass(
- ('a', schema.Scalar()),
- )
- clone = s.clone()
- self.assertIsInstance(clone, Subclass)
- self.assertEqual(s, clone)
- self.assertIsNot(clone, s)
- def testNormalizeField(self):
- s = schema.Struct(('field1', np.int32), ('field2', str))
- self.assertEquals(
- s,
- schema.Struct(
- ('field1', schema.Scalar(dtype=np.int32)),
- ('field2', schema.Scalar(dtype=str))
- )
- )
- def testTuple(self):
- s = schema.Tuple(np.int32, str, np.float32)
- s2 = schema.Struct(
- ('field_0', schema.Scalar(dtype=np.int32)),
- ('field_1', schema.Scalar(dtype=np.str)),
- ('field_2', schema.Scalar(dtype=np.float32))
- )
- self.assertEquals(s, s2)
- self.assertEquals(s[0], schema.Scalar(dtype=np.int32))
- self.assertEquals(s[1], schema.Scalar(dtype=np.str))
- self.assertEquals(s[2], schema.Scalar(dtype=np.float32))
- self.assertEquals(
- s[2, 0],
- schema.Struct(
- ('field_2', schema.Scalar(dtype=np.float32)),
- ('field_0', schema.Scalar(dtype=np.int32)),
- )
- )
- # test iterator behavior
- for i, (v1, v2) in enumerate(zip(s, s2)):
- self.assertEquals(v1, v2)
- self.assertEquals(s[i], v1)
- self.assertEquals(s2[i], v1)
- def testRawTuple(self):
- s = schema.RawTuple(2)
- self.assertEquals(
- s, schema.Struct(
- ('field_0', schema.Scalar()), ('field_1', schema.Scalar())
- )
- )
- self.assertEquals(s[0], schema.Scalar())
- self.assertEquals(s[1], schema.Scalar())
- def testStructIndexing(self):
- s = schema.Struct(
- ('field1', schema.Scalar(dtype=np.int32)),
- ('field2', schema.List(schema.Scalar(dtype=str))),
- ('field3', schema.Struct()),
- )
- self.assertEquals(s['field2'], s.field2)
- self.assertEquals(s['field2'], schema.List(schema.Scalar(dtype=str)))
- self.assertEquals(s['field3'], schema.Struct())
- self.assertEquals(
- s['field2', 'field1'],
- schema.Struct(
- ('field2', schema.List(schema.Scalar(dtype=str))),
- ('field1', schema.Scalar(dtype=np.int32)),
- )
- )
- def testListInStructIndexing(self):
- a = schema.List(schema.Scalar(dtype=str))
- s = schema.Struct(
- ('field1', schema.Scalar(dtype=np.int32)),
- ('field2', a)
- )
- self.assertEquals(s['field2:lengths'], a.lengths)
- self.assertEquals(s['field2:values'], a.items)
- with self.assertRaises(KeyError):
- s['fields2:items:non_existent']
- with self.assertRaises(KeyError):
- s['fields2:non_existent']
- def testListWithEvictedInStructIndexing(self):
- a = schema.ListWithEvicted(schema.Scalar(dtype=str))
- s = schema.Struct(
- ('field1', schema.Scalar(dtype=np.int32)),
- ('field2', a)
- )
- self.assertEquals(s['field2:lengths'], a.lengths)
- self.assertEquals(s['field2:values'], a.items)
- self.assertEquals(s['field2:_evicted_values'], a._evicted_values)
- with self.assertRaises(KeyError):
- s['fields2:items:non_existent']
- with self.assertRaises(KeyError):
- s['fields2:non_existent']
- def testMapInStructIndexing(self):
- a = schema.Map(
- schema.Scalar(dtype=np.int32),
- schema.Scalar(dtype=np.float32),
- )
- s = schema.Struct(
- ('field1', schema.Scalar(dtype=np.int32)),
- ('field2', a)
- )
- self.assertEquals(s['field2:values:keys'], a.keys)
- self.assertEquals(s['field2:values:values'], a.values)
- with self.assertRaises(KeyError):
- s['fields2:keys:non_existent']
- def testPreservesMetadata(self):
- s = schema.Struct(
- ('a', schema.Scalar(np.float32)), (
- 'b', schema.Scalar(
- np.int32,
- metadata=schema.Metadata(categorical_limit=5)
- )
- ), (
- 'c', schema.List(
- schema.Scalar(
- np.int32,
- metadata=schema.Metadata(categorical_limit=6)
- )
- )
- )
- )
- # attach metadata to lengths field
- s.c.lengths.set_metadata(schema.Metadata(categorical_limit=7))
- self.assertEqual(None, s.a.metadata)
- self.assertEqual(5, s.b.metadata.categorical_limit)
- self.assertEqual(6, s.c.value.metadata.categorical_limit)
- self.assertEqual(7, s.c.lengths.metadata.categorical_limit)
- sc = s.clone()
- self.assertEqual(None, sc.a.metadata)
- self.assertEqual(5, sc.b.metadata.categorical_limit)
- self.assertEqual(6, sc.c.value.metadata.categorical_limit)
- self.assertEqual(7, sc.c.lengths.metadata.categorical_limit)
- sv = schema.from_blob_list(
- s, [
- np.array([3.4]), np.array([2]), np.array([3]),
- np.array([1, 2, 3])
- ]
- )
- self.assertEqual(None, sv.a.metadata)
- self.assertEqual(5, sv.b.metadata.categorical_limit)
- self.assertEqual(6, sv.c.value.metadata.categorical_limit)
- self.assertEqual(7, sv.c.lengths.metadata.categorical_limit)
- def testDupField(self):
- with self.assertRaises(ValueError):
- schema.Struct(
- ('a', schema.Scalar()),
- ('a', schema.Scalar()))
- def testAssignToField(self):
- with self.assertRaises(TypeError):
- s = schema.Struct(('a', schema.Scalar()))
- s.a = schema.Scalar()
- def testPreservesEmptyFields(self):
- s = schema.Struct(
- ('a', schema.Scalar(np.float32)),
- ('b', schema.Struct()),
- )
- sc = s.clone()
- self.assertIn("a", sc.fields)
- self.assertIn("b", sc.fields)
- sv = schema.from_blob_list(s, [np.array([3.4])])
- self.assertIn("a", sv.fields)
- self.assertIn("b", sv.fields)
- self.assertEqual(0, len(sv.b.fields))
- def testStructSubstraction(self):
- s1 = schema.Struct(
- ('a', schema.Scalar()),
- ('b', schema.Scalar()),
- ('c', schema.Scalar()),
- )
- s2 = schema.Struct(
- ('b', schema.Scalar())
- )
- s = s1 - s2
- self.assertEqual(['a', 'c'], s.field_names())
- s3 = schema.Struct(
- ('a', schema.Scalar())
- )
- s = s1 - s3
- self.assertEqual(['b', 'c'], s.field_names())
- with self.assertRaises(TypeError):
- s1 - schema.Scalar()
- def testStructNestedSubstraction(self):
- s1 = schema.Struct(
- ('a', schema.Scalar()),
- ('b', schema.Struct(
- ('c', schema.Scalar()),
- ('d', schema.Scalar()),
- ('e', schema.Scalar()),
- ('f', schema.Scalar()),
- )),
- )
- s2 = schema.Struct(
- ('b', schema.Struct(
- ('d', schema.Scalar()),
- ('e', schema.Scalar()),
- )),
- )
- s = s1 - s2
- self.assertEqual(['a', 'b:c', 'b:f'], s.field_names())
- def testStructAddition(self):
- s1 = schema.Struct(
- ('a', schema.Scalar())
- )
- s2 = schema.Struct(
- ('b', schema.Scalar())
- )
- s = s1 + s2
- self.assertIn("a", s.fields)
- self.assertIn("b", s.fields)
- with self.assertRaises(TypeError):
- s1 + s1
- with self.assertRaises(TypeError):
- s1 + schema.Scalar()
- def testStructNestedAddition(self):
- s1 = schema.Struct(
- ('a', schema.Scalar()),
- ('b', schema.Struct(
- ('c', schema.Scalar())
- )),
- )
- s2 = schema.Struct(
- ('b', schema.Struct(
- ('d', schema.Scalar())
- ))
- )
- s = s1 + s2
- self.assertEqual(['a', 'b:c', 'b:d'], s.field_names())
- s3 = schema.Struct(
- ('b', schema.Scalar()),
- )
- with self.assertRaises(TypeError):
- s = s1 + s3
- def testGetFieldByNestedName(self):
- st = schema.Struct(
- ('a', schema.Scalar()),
- ('b', schema.Struct(
- ('c', schema.Struct(
- ('d', schema.Scalar()),
- )),
- )),
- )
- self.assertRaises(KeyError, st.__getitem__, '')
- self.assertRaises(KeyError, st.__getitem__, 'x')
- self.assertRaises(KeyError, st.__getitem__, 'x:y')
- self.assertRaises(KeyError, st.__getitem__, 'b:c:x')
- a = st['a']
- self.assertTrue(isinstance(a, schema.Scalar))
- bc = st['b:c']
- self.assertIn('d', bc.fields)
- bcd = st['b:c:d']
- self.assertTrue(isinstance(bcd, schema.Scalar))
- def testAddFieldByNestedName(self):
- f_a = schema.Scalar(blob=core.BlobReference('blob1'))
- f_b = schema.Struct(
- ('c', schema.Struct(
- ('d', schema.Scalar(blob=core.BlobReference('blob2'))),
- )),
- )
- f_x = schema.Struct(
- ('x', schema.Scalar(blob=core.BlobReference('blob3'))),
- )
- with self.assertRaises(TypeError):
- st = schema.Struct(
- ('a', f_a),
- ('b', f_b),
- ('b:c:d', f_x),
- )
- with self.assertRaises(TypeError):
- st = schema.Struct(
- ('a', f_a),
- ('b', f_b),
- ('b:c:d:e', f_x),
- )
- st = schema.Struct(
- ('a', f_a),
- ('b', f_b),
- ('e:f', f_x),
- )
- self.assertEqual(['a', 'b:c:d', 'e:f:x'], st.field_names())
- self.assertEqual(['blob1', 'blob2', 'blob3'], st.field_blobs())
- st = schema.Struct(
- ('a', f_a),
- ('b:c:e', f_x),
- ('b', f_b),
- )
- self.assertEqual(['a', 'b:c:e:x', 'b:c:d'], st.field_names())
- self.assertEqual(['blob1', 'blob3', 'blob2'], st.field_blobs())
- st = schema.Struct(
- ('a:a1', f_a),
- ('b:b1', f_b),
- ('a', f_x),
- )
- self.assertEqual(['a:a1', 'a:x', 'b:b1:c:d'], st.field_names())
- self.assertEqual(['blob1', 'blob3', 'blob2'], st.field_blobs())
- def testContains(self):
- st = schema.Struct(
- ('a', schema.Scalar()),
- ('b', schema.Struct(
- ('c', schema.Struct(
- ('d', schema.Scalar()),
- )),
- )),
- )
- self.assertTrue('a' in st)
- self.assertTrue('b:c' in st)
- self.assertTrue('b:c:d' in st)
- self.assertFalse('' in st)
- self.assertFalse('x' in st)
- self.assertFalse('b:c:x' in st)
- self.assertFalse('b:c:d:x' in st)
- def testFromEmptyColumnList(self):
- st = schema.Struct()
- columns = st.field_names()
- rec = schema.from_column_list(col_names=columns)
- self.assertEqual(rec, schema.Struct())
- def testFromColumnList(self):
- st = schema.Struct(
- ('a', schema.Scalar()),
- ('b', schema.List(schema.Scalar())),
- ('c', schema.Map(schema.Scalar(), schema.Scalar()))
- )
- columns = st.field_names()
- # test that recovery works for arbitrary order
- for _ in range(10):
- some_blobs = [core.BlobReference('blob:' + x) for x in columns]
- rec = schema.from_column_list(columns, col_blobs=some_blobs)
- self.assertTrue(rec.has_blobs())
- self.assertEqual(sorted(st.field_names()), sorted(rec.field_names()))
- self.assertEqual([str(blob) for blob in rec.field_blobs()],
- [str('blob:' + name) for name in rec.field_names()])
- random.shuffle(columns)
- def testStructGet(self):
- net = core.Net('test_net')
- s1 = schema.NewRecord(net, schema.Scalar(np.float32))
- s2 = schema.NewRecord(net, schema.Scalar(np.float32))
- t = schema.Tuple(s1, s2)
- assert t.get('field_0', None) == s1
- assert t.get('field_1', None) == s2
- assert t.get('field_2', None) is None
- def testScalarForVoidType(self):
- s0_good = schema.Scalar((None, (2, )))
- with self.assertRaises(TypeError):
- s0_bad = schema.Scalar((np.void, (2, )))
- s1_good = schema.Scalar(np.void)
- s2_good = schema.Scalar(None)
- assert s1_good == s2_good
- def testScalarShape(self):
- s0 = schema.Scalar(np.int32)
- self.assertEqual(s0.field_type().shape, ())
- s1_good = schema.Scalar((np.int32, 5))
- self.assertEqual(s1_good.field_type().shape, (5, ))
- with self.assertRaises(ValueError):
- s1_bad = schema.Scalar((np.int32, -1))
- s1_hard = schema.Scalar((np.int32, 1))
- self.assertEqual(s1_hard.field_type().shape, (1, ))
- s2 = schema.Scalar((np.int32, (2, 3)))
- self.assertEqual(s2.field_type().shape, (2, 3))
- def testDtypeForCoreType(self):
- dtype = schema.dtype_for_core_type(core.DataType.FLOAT16)
- self.assertEqual(dtype, np.float16)
- with self.assertRaises(TypeError):
- schema.dtype_for_core_type(100)
|