voc.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import collections
  2. import os
  3. from xml.etree.ElementTree import Element as ET_Element
  4. from .vision import VisionDataset
  5. try:
  6. from defusedxml.ElementTree import parse as ET_parse
  7. except ImportError:
  8. from xml.etree.ElementTree import parse as ET_parse
  9. import warnings
  10. from typing import Any, Callable, Dict, Optional, Tuple, List
  11. from PIL import Image
  12. from .utils import download_and_extract_archive, verify_str_arg
  13. DATASET_YEAR_DICT = {
  14. "2012": {
  15. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
  16. "filename": "VOCtrainval_11-May-2012.tar",
  17. "md5": "6cd6e144f989b92b3379bac3b3de84fd",
  18. "base_dir": os.path.join("VOCdevkit", "VOC2012"),
  19. },
  20. "2011": {
  21. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
  22. "filename": "VOCtrainval_25-May-2011.tar",
  23. "md5": "6c3384ef61512963050cb5d687e5bf1e",
  24. "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
  25. },
  26. "2010": {
  27. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
  28. "filename": "VOCtrainval_03-May-2010.tar",
  29. "md5": "da459979d0c395079b5c75ee67908abb",
  30. "base_dir": os.path.join("VOCdevkit", "VOC2010"),
  31. },
  32. "2009": {
  33. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
  34. "filename": "VOCtrainval_11-May-2009.tar",
  35. "md5": "a3e00b113cfcfebf17e343f59da3caa1",
  36. "base_dir": os.path.join("VOCdevkit", "VOC2009"),
  37. },
  38. "2008": {
  39. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
  40. "filename": "VOCtrainval_11-May-2012.tar",
  41. "md5": "2629fa636546599198acfcfbfcf1904a",
  42. "base_dir": os.path.join("VOCdevkit", "VOC2008"),
  43. },
  44. "2007": {
  45. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
  46. "filename": "VOCtrainval_06-Nov-2007.tar",
  47. "md5": "c52e279531787c972589f7e41ab4ae64",
  48. "base_dir": os.path.join("VOCdevkit", "VOC2007"),
  49. },
  50. "2007-test": {
  51. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
  52. "filename": "VOCtest_06-Nov-2007.tar",
  53. "md5": "b6e924de25625d8de591ea690078ad9f",
  54. "base_dir": os.path.join("VOCdevkit", "VOC2007"),
  55. },
  56. }
  57. class _VOCBase(VisionDataset):
  58. _SPLITS_DIR: str
  59. _TARGET_DIR: str
  60. _TARGET_FILE_EXT: str
  61. def __init__(
  62. self,
  63. root: str,
  64. year: str = "2012",
  65. image_set: str = "train",
  66. download: bool = False,
  67. transform: Optional[Callable] = None,
  68. target_transform: Optional[Callable] = None,
  69. transforms: Optional[Callable] = None,
  70. ):
  71. super().__init__(root, transforms, transform, target_transform)
  72. if year == "2007-test":
  73. if image_set == "test":
  74. warnings.warn(
  75. "Accessing the test image set of the year 2007 with year='2007-test' is deprecated "
  76. "since 0.12 and will be removed in 0.14. "
  77. "Please use the combination year='2007' and image_set='test' instead."
  78. )
  79. year = "2007"
  80. else:
  81. raise ValueError(
  82. "In the test image set of the year 2007 only image_set='test' is allowed. "
  83. "For all other image sets use year='2007' instead."
  84. )
  85. self.year = year
  86. valid_image_sets = ["train", "trainval", "val"]
  87. if year == "2007":
  88. valid_image_sets.append("test")
  89. self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
  90. key = "2007-test" if year == "2007" and image_set == "test" else year
  91. dataset_year_dict = DATASET_YEAR_DICT[key]
  92. self.url = dataset_year_dict["url"]
  93. self.filename = dataset_year_dict["filename"]
  94. self.md5 = dataset_year_dict["md5"]
  95. base_dir = dataset_year_dict["base_dir"]
  96. voc_root = os.path.join(self.root, base_dir)
  97. if download:
  98. download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
  99. if not os.path.isdir(voc_root):
  100. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  101. splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
  102. split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
  103. with open(os.path.join(split_f)) as f:
  104. file_names = [x.strip() for x in f.readlines()]
  105. image_dir = os.path.join(voc_root, "JPEGImages")
  106. self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
  107. target_dir = os.path.join(voc_root, self._TARGET_DIR)
  108. self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
  109. assert len(self.images) == len(self.targets)
  110. def __len__(self) -> int:
  111. return len(self.images)
  112. class VOCSegmentation(_VOCBase):
  113. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
  114. Args:
  115. root (string): Root directory of the VOC Dataset.
  116. year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
  117. image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
  118. ``year=="2007"``, can also be ``"test"``.
  119. download (bool, optional): If true, downloads the dataset from the internet and
  120. puts it in root directory. If dataset is already downloaded, it is not
  121. downloaded again.
  122. transform (callable, optional): A function/transform that takes in an PIL image
  123. and returns a transformed version. E.g, ``transforms.RandomCrop``
  124. target_transform (callable, optional): A function/transform that takes in the
  125. target and transforms it.
  126. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  127. and returns a transformed version.
  128. """
  129. _SPLITS_DIR = "Segmentation"
  130. _TARGET_DIR = "SegmentationClass"
  131. _TARGET_FILE_EXT = ".png"
  132. @property
  133. def masks(self) -> List[str]:
  134. return self.targets
  135. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  136. """
  137. Args:
  138. index (int): Index
  139. Returns:
  140. tuple: (image, target) where target is the image segmentation.
  141. """
  142. img = Image.open(self.images[index]).convert("RGB")
  143. target = Image.open(self.masks[index])
  144. if self.transforms is not None:
  145. img, target = self.transforms(img, target)
  146. return img, target
  147. class VOCDetection(_VOCBase):
  148. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
  149. Args:
  150. root (string): Root directory of the VOC Dataset.
  151. year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
  152. image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
  153. ``year=="2007"``, can also be ``"test"``.
  154. download (bool, optional): If true, downloads the dataset from the internet and
  155. puts it in root directory. If dataset is already downloaded, it is not
  156. downloaded again.
  157. (default: alphabetic indexing of VOC's 20 classes).
  158. transform (callable, optional): A function/transform that takes in an PIL image
  159. and returns a transformed version. E.g, ``transforms.RandomCrop``
  160. target_transform (callable, required): A function/transform that takes in the
  161. target and transforms it.
  162. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  163. and returns a transformed version.
  164. """
  165. _SPLITS_DIR = "Main"
  166. _TARGET_DIR = "Annotations"
  167. _TARGET_FILE_EXT = ".xml"
  168. @property
  169. def annotations(self) -> List[str]:
  170. return self.targets
  171. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  172. """
  173. Args:
  174. index (int): Index
  175. Returns:
  176. tuple: (image, target) where target is a dictionary of the XML tree.
  177. """
  178. img = Image.open(self.images[index]).convert("RGB")
  179. target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
  180. if self.transforms is not None:
  181. img, target = self.transforms(img, target)
  182. return img, target
  183. @staticmethod
  184. def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
  185. voc_dict: Dict[str, Any] = {}
  186. children = list(node)
  187. if children:
  188. def_dic: Dict[str, Any] = collections.defaultdict(list)
  189. for dc in map(VOCDetection.parse_voc_xml, children):
  190. for ind, v in dc.items():
  191. def_dic[ind].append(v)
  192. if node.tag == "annotation":
  193. def_dic["object"] = [def_dic["object"]]
  194. voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
  195. if node.text:
  196. text = node.text.strip()
  197. if not children:
  198. voc_dict[node.tag] = text
  199. return voc_dict