| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- ## @package visualize
- # Module caffe2.python.visualize
- """Functions that could be used to visualize Tensors.
- This is adapted from the old-time iceberk package that Yangqing wrote... Oh, golden
- memories. Before decaf and caffe. Why iceberk? Because I was at Berkeley,
- bears are vegetarian, and iceberg lettuce has layers of leaves.
- (This joke is so lame.)
- """
- import numpy as np
- from matplotlib import cm, pyplot
def ChannelFirst(arr):
    """Move the trailing channel axis in front of the two axes before it.

    An array laid out as (..., H, W, C) comes back as (..., C, H, W). Only the
    last three axes are permuted; any leading axes keep their position.
    """
    rank = arr.ndim
    first, middle, last = rank - 3, rank - 2, rank - 1
    # (..., H, W, C) -> (..., H, C, W)
    partially_swapped = arr.swapaxes(last, middle)
    # (..., H, C, W) -> (..., C, H, W)
    return partially_swapped.swapaxes(middle, first)
def ChannelLast(arr):
    """Move the channel axis from third-to-last position to the very end.

    An array laid out as (..., C, H, W) comes back as (..., H, W, C). Only the
    last three axes are permuted; any leading axes keep their position.
    """
    rank = arr.ndim
    first, middle, last = rank - 3, rank - 2, rank - 1
    # (..., C, H, W) -> (..., H, C, W)
    partially_swapped = arr.swapaxes(first, middle)
    # (..., H, C, W) -> (..., H, W, C)
    return partially_swapped.swapaxes(middle, last)
class PatchVisualizer(object):
    """Draws image patches with matplotlib, singly or tiled into a grid.

    Attributes:
        gap: number of background pixels left between neighbouring patches
            in the grid built by ShowMultiple.
    """

    def __init__(self, gap=1):
        self.gap = gap

    def ShowSingle(self, patch, cmap=None):
        """Visualizes one single patch.

        The input patch could be a vector (in which case we try to infer the
        shape of the patch), a 2-D matrix, or a 3-D matrix whose 3rd dimension
        has 3 channels.

        Args:
            patch: the array to draw.
            cmap: matplotlib colormap; defaults to gray for 2-D patches.

        Returns:
            The patch that was drawn (reshaped if a vector was passed in).

        Raises:
            ValueError: if the patch has more than two dimensions and its
                third dimension is not 3, or if a vector's shape cannot be
                inferred.
        """
        if len(patch.shape) == 1:
            patch = patch.reshape(self.get_patch_shape(patch))
        elif len(patch.shape) > 2 and patch.shape[2] != 3:
            raise ValueError("The input patch shape isn't correct.")
        # determine color
        if len(patch.shape) == 2 and cmap is None:
            cmap = cm.gray
        pyplot.imshow(patch, cmap=cmap)
        return patch

    def ShowMultiple(self, patches, ncols=None, cmap=None, bg_func=np.mean):
        """Visualize multiple patches.

        In the passed in patches matrix, each row is a patch, in the shape of
        either n*n, n*n*1 or n*n*3, either in a flattened format (so patches
        would be a 2-D array), or a multi-dimensional tensor. We will try our
        best to figure out automatically the patch size.

        Args:
            patches: array whose first axis indexes the patches.
            ncols: number of grid columns; defaults to
                ceil(sqrt(number of patches)).
            cmap: matplotlib colormap; defaults to gray unless the patches
                have 3 channels.
            bg_func: called on the whole patches array; its result fills the
                gaps and any unused grid cells.

        Returns:
            The composed grid image that was drawn.

        Raises:
            ValueError: if 4-D patches have a last dimension other than 1 or
                3, or if the shape of flattened patches cannot be inferred.
        """
        num_patches = patches.shape[0]
        if ncols is None:
            ncols = int(np.ceil(np.sqrt(num_patches)))
        nrows = int(np.ceil(num_patches / float(ncols)))
        if len(patches.shape) == 2:
            # Flattened patches: recover the per-patch shape from one row.
            patches = patches.reshape(
                (patches.shape[0], ) + self.get_patch_shape(patches[0])
            )
        # Each grid cell is one patch plus the gap; the final image drops the
        # trailing gap after the last row and column.
        patch_size_expand = np.array(patches.shape[1:3]) + self.gap
        image_size = patch_size_expand * np.array([nrows, ncols]) - self.gap
        if len(patches.shape) == 4:
            if patches.shape[3] == 1:
                # gray patches
                patches = patches.reshape(patches.shape[:-1])
                image_shape = tuple(image_size)
                if cmap is None:
                    cmap = cm.gray
            elif patches.shape[3] == 3:
                # color patches
                image_shape = tuple(image_size) + (3, )
            else:
                raise ValueError("The input patch shape isn't expected.")
        else:
            image_shape = tuple(image_size)
            if cmap is None:
                cmap = cm.gray
        image = np.ones(image_shape) * bg_func(patches)
        for pid in range(num_patches):
            row = pid // ncols * patch_size_expand[0]
            col = pid % ncols * patch_size_expand[1]
            image[row:row+patches.shape[1], col:col+patches.shape[2]] = \
                patches[pid]
        pyplot.imshow(image, cmap=cmap, interpolation='nearest')
        pyplot.axis('off')
        return image

    def ShowImages(self, patches, *args, **kwargs):
        """Similar to ShowMultiple, but always normalize the values between 0
        and 1 for better visualization of image-type data.

        Extra positional and keyword arguments are passed on to ShowMultiple,
        whose return value is returned unchanged.
        """
        # Normalize on a float copy: integer input (e.g. uint8 images) cannot
        # be divided in place, and subtracting the minimum could overflow a
        # narrow signed integer type.
        patches = np.asarray(patches, dtype=np.float64)
        patches = patches - np.min(patches)
        # eps keeps the division defined when every value is identical.
        patches /= np.max(patches) + np.finfo(np.float64).eps
        return self.ShowMultiple(patches, *args, **kwargs)

    def ShowChannels(self, patch, cmap=None, bg_func=np.mean):
        """ This function shows the channels of a patch.

        The incoming patch should have shape [w, h, num_channels], and each
        channel will be visualized as a separate gray patch.

        Raises:
            ValueError: if the patch is not 3-dimensional.
        """
        if len(patch.shape) != 3:
            raise ValueError("The input patch shape isn't correct.")
        # [w, h, c] -> (transpose) [c, h, w] -> (swap) [c, w, h], so that each
        # channel becomes one patch along the first axis.
        patch_reordered = np.swapaxes(patch.T, 1, 2)
        return self.ShowMultiple(patch_reordered, cmap=cmap, bg_func=bg_func)

    def get_patch_shape(self, patch):
        """Gets the shape of a single patch.

        Basically it tries to interpret the patch as a square, and also check
        if it is in color (3 channels).

        Returns:
            (edge, edge) when the element count is a perfect square, otherwise
            (edge, edge, 3) when it is three times a perfect square. All
            entries are plain ints so the tuple can be passed to reshape.

        Raises:
            ValueError: if the element count fits neither layout.
        """
        size = int(patch.size)
        # Verify candidates with integer arithmetic instead of comparing
        # floating-point square roots for equality.
        edge = int(round(np.sqrt(size)))
        if edge * edge == size:
            return (edge, edge)
        if size % 3 == 0:
            # we are given color patches
            edge = int(round(np.sqrt(size // 3)))
            if edge * edge * 3 == size:
                return (edge, edge, 3)
        raise ValueError("I can't figure out the patch shape.")
# Shared module-level PatchVisualizer, built with the default gap, that the
# NHWC and NCHW convenience wrappers delegate to.
_default_visualizer = PatchVisualizer()
"""Utility functions that directly point to functions in the default visualizer.
These functions don't return anything, so you won't see annoying printouts of
the visualized images. If you want to save the images for example, you should
explicitly instantiate a patch visualizer, and call those functions.
"""
class NHWC(object):
    """Shortcuts to the default visualizer for channel-last data.

    Each static method hands its arguments, unchanged, to the method of the
    same name on the module-level default PatchVisualizer and discards the
    result, so every call returns None.
    """

    @staticmethod
    def ShowSingle(*args, **kwargs):
        """Forward to the default visualizer's ShowSingle; returns None."""
        show = _default_visualizer.ShowSingle
        show(*args, **kwargs)

    @staticmethod
    def ShowMultiple(*args, **kwargs):
        """Forward to the default visualizer's ShowMultiple; returns None."""
        show = _default_visualizer.ShowMultiple
        show(*args, **kwargs)

    @staticmethod
    def ShowImages(*args, **kwargs):
        """Forward to the default visualizer's ShowImages; returns None."""
        show = _default_visualizer.ShowImages
        show(*args, **kwargs)

    @staticmethod
    def ShowChannels(*args, **kwargs):
        """Forward to the default visualizer's ShowChannels; returns None."""
        show = _default_visualizer.ShowChannels
        show(*args, **kwargs)
class NCHW(object):
    """Shortcuts to the default visualizer for channel-first data.

    Each static method passes its first argument through ChannelLast, then
    hands it together with the remaining arguments to the method of the same
    name on the module-level default PatchVisualizer. The result is discarded,
    so every call returns None.
    """

    @staticmethod
    def ShowSingle(patch, *args, **kwargs):
        """Convert patch with ChannelLast, then forward to ShowSingle."""
        show = _default_visualizer.ShowSingle
        show(ChannelLast(patch), *args, **kwargs)

    @staticmethod
    def ShowMultiple(patch, *args, **kwargs):
        """Convert patch with ChannelLast, then forward to ShowMultiple."""
        show = _default_visualizer.ShowMultiple
        show(ChannelLast(patch), *args, **kwargs)

    @staticmethod
    def ShowImages(patch, *args, **kwargs):
        """Convert patch with ChannelLast, then forward to ShowImages."""
        show = _default_visualizer.ShowImages
        show(ChannelLast(patch), *args, **kwargs)

    @staticmethod
    def ShowChannels(patch, *args, **kwargs):
        """Convert patch with ChannelLast, then forward to ShowChannels."""
        show = _default_visualizer.ShowChannels
        show(ChannelLast(patch), *args, **kwargs)
|