video_reader.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
import warnings
from typing import Any, Dict, Iterator

import torch

from ..utils import _log_api_usage_once
  4. try:
  5. from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
  6. except ModuleNotFoundError:
  7. _HAS_GPU_VIDEO_DECODER = False
  8. from ._video_opt import (
  9. _HAS_VIDEO_OPT,
  10. )
  11. if _HAS_VIDEO_OPT:
  12. def _has_video_opt() -> bool:
  13. return True
  14. else:
  15. def _has_video_opt() -> bool:
  16. return False
  17. class VideoReader:
  18. """
  19. Fine-grained video-reading API.
  20. Supports frame-by-frame reading of various streams from a single video
  21. container.
  22. .. betastatus:: VideoReader class
  23. Example:
  24. The following examples creates a :mod:`VideoReader` object, seeks into 2s
  25. point, and returns a single frame::
  26. import torchvision
  27. video_path = "path_to_a_test_video"
  28. reader = torchvision.io.VideoReader(video_path, "video")
  29. reader.seek(2.0)
  30. frame = next(reader)
  31. :mod:`VideoReader` implements the iterable API, which makes it suitable to
  32. using it in conjunction with :mod:`itertools` for more advanced reading.
  33. As such, we can use a :mod:`VideoReader` instance inside for loops::
  34. reader.seek(2)
  35. for frame in reader:
  36. frames.append(frame['data'])
  37. # additionally, `seek` implements a fluent API, so we can do
  38. for frame in reader.seek(2):
  39. frames.append(frame['data'])
  40. With :mod:`itertools`, we can read all frames between 2 and 5 seconds with the
  41. following code::
  42. for frame in itertools.takewhile(lambda x: x['pts'] <= 5, reader.seek(2)):
  43. frames.append(frame['data'])
  44. and similarly, reading 10 frames after the 2s timestamp can be achieved
  45. as follows::
  46. for frame in itertools.islice(reader.seek(2), 10):
  47. frames.append(frame['data'])
  48. .. note::
  49. Each stream descriptor consists of two parts: stream type (e.g. 'video') and
  50. a unique stream id (which are determined by the video encoding).
  51. In this way, if the video contaner contains multiple
  52. streams of the same type, users can acces the one they want.
  53. If only stream type is passed, the decoder auto-detects first stream of that type.
  54. Args:
  55. path (string): Path to the video file in supported format
  56. stream (string, optional): descriptor of the required stream, followed by the stream id,
  57. in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
  58. Currently available options include ``['video', 'audio']``
  59. num_threads (int, optional): number of threads used by the codec to decode video.
  60. Default value (0) enables multithreading with codec-dependent heuristic. The performance
  61. will depend on the version of FFMPEG codecs supported.
  62. device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
  63. To use GPU decoding, pass ``device="cuda"``.
  64. """
  65. def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
  66. _log_api_usage_once(self)
  67. self.is_cuda = False
  68. device = torch.device(device)
  69. if device.type == "cuda":
  70. if not _HAS_GPU_VIDEO_DECODER:
  71. raise RuntimeError("Not compiled with GPU decoder support.")
  72. self.is_cuda = True
  73. self._c = torch.classes.torchvision.GPUDecoder(path, device)
  74. return
  75. if not _has_video_opt():
  76. raise RuntimeError(
  77. "Not compiled with video_reader support, "
  78. + "to enable video_reader support, please install "
  79. + "ffmpeg (version 4.2 is currently supported) and "
  80. + "build torchvision from source."
  81. )
  82. self._c = torch.classes.torchvision.Video(path, stream, num_threads)
  83. def __next__(self) -> Dict[str, Any]:
  84. """Decodes and returns the next frame of the current stream.
  85. Frames are encoded as a dict with mandatory
  86. data and pts fields, where data is a tensor, and pts is a
  87. presentation timestamp of the frame expressed in seconds
  88. as a float.
  89. Returns:
  90. (dict): a dictionary and containing decoded frame (``data``)
  91. and corresponding timestamp (``pts``) in seconds
  92. """
  93. if self.is_cuda:
  94. frame = self._c.next()
  95. if frame.numel() == 0:
  96. raise StopIteration
  97. return {"data": frame}
  98. frame, pts = self._c.next()
  99. if frame.numel() == 0:
  100. raise StopIteration
  101. return {"data": frame, "pts": pts}
  102. def __iter__(self) -> Iterator[Dict[str, Any]]:
  103. return self
  104. def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
  105. """Seek within current stream.
  106. Args:
  107. time_s (float): seek time in seconds
  108. keyframes_only (bool): allow to seek only to keyframes
  109. .. note::
  110. Current implementation is the so-called precise seek. This
  111. means following seek, call to :mod:`next()` will return the
  112. frame with the exact timestamp if it exists or
  113. the first frame with timestamp larger than ``time_s``.
  114. """
  115. self._c.seek(time_s, keyframes_only)
  116. return self
  117. def get_metadata(self) -> Dict[str, Any]:
  118. """Returns video metadata
  119. Returns:
  120. (dict): dictionary containing duration and frame rate for every stream
  121. """
  122. return self._c.get_metadata()
  123. def set_current_stream(self, stream: str) -> bool:
  124. """Set current stream.
  125. Explicitly define the stream we are operating on.
  126. Args:
  127. stream (string): descriptor of the required stream. Defaults to ``"video:0"``
  128. Currently available stream types include ``['video', 'audio']``.
  129. Each descriptor consists of two parts: stream type (e.g. 'video') and
  130. a unique stream id (which are determined by video encoding).
  131. In this way, if the video contaner contains multiple
  132. streams of the same type, users can acces the one they want.
  133. If only stream type is passed, the decoder auto-detects first stream
  134. of that type and returns it.
  135. Returns:
  136. (bool): True on succes, False otherwise
  137. """
  138. if self.is_cuda:
  139. print("GPU decoding only works with video stream.")
  140. return self._c.set_current_stream(stream)