# Source code for cntk.contrib.deeprl.agent.shared.preprocessing

# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
"""Base class for defining preprocessing, as well as two concrete examples."""

from abc import ABCMeta, abstractmethod
from collections import deque

import numpy as np
from PIL import Image


# The base is created by calling ABCMeta directly so that the metaclass is
# applied under both Python 2 and Python 3. A bare ``__metaclass__`` class
# attribute is ignored by Python 3, which would leave the abstract methods
# below unenforced.
class Preprocessing(ABCMeta('_PreprocessingBase', (object,), {})):
    """Base class for defining preprocessing.

    All subclass constructors will take input_shape as the first argument.

    Subclasses must implement ``output_shape``, ``reset`` and ``preprocess``;
    instantiating this class, or a subclass that leaves any of them
    unimplemented, raises ``TypeError``.
    """

    def __init__(self, input_shape):
        """Constructor for base Preprocessing class.

        Args:
            input_shape: shape of a single raw observation. It is stored on
                the instance as ``_input_shape`` for use by subclasses.
        """
        self._input_shape = input_shape

    @abstractmethod
    def output_shape(self):
        """Return shape of preprocessed observation."""

    @abstractmethod
    def reset(self):
        """Reset preprocessing pipeline for new episode."""

    @abstractmethod
    def preprocess(self, observation):
        """Return preprocessed observation.

        Args:
            observation: a single raw observation.
        """
class AtariPreprocessing(Preprocessing):
    """Preprocess screen images from Atari 2600 games.

    The image is represented by an array of shape (210, 160, 3). Each call to
    ``preprocess`` merges the new screen with the previous one, reduces it to
    an 84 x 84 luminance frame, and returns the last ``history_len`` such
    frames stacked along a new leading axis.

    See https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
    for more details.
    """

    def __init__(self, input_shape, history_len=4):
        """Create the pipeline and put it in its start-of-episode state.

        Args:
            input_shape: expected shape of every raw screen image.
            history_len: number of processed frames kept and stacked.
        """
        super(AtariPreprocessing, self).__init__(input_shape)
        self.__history_len = history_len
        # A bounded deque drops the oldest frame automatically on append.
        self.__processed_image_seq = deque(maxlen=history_len)
        self.reset()

    def output_shape(self):
        """Return shape of preprocessed Atari images."""
        frame_count = self.__history_len
        return (frame_count, 84, 84)

    def reset(self):
        """Reset preprocessing pipeline for new episode.

        The remembered previous screen becomes an all-zero uint8 image and the
        frame history is refilled with ``history_len`` all-zero 84 x 84 frames.
        """
        self.__previous_raw_image = np.zeros(self._input_shape, dtype=np.uint8)
        self.__processed_image_seq.clear()
        # NOTE(review): these padding frames use np.zeros' default dtype, while
        # real frames are converted from a PIL image; confirm the dtype of the
        # stacked result is acceptable during the first steps of an episode.
        self.__processed_image_seq.extend(
            np.zeros((84, 84)) for _ in range(self.__history_len))

    def preprocess(self, image):
        """Return preprocessed screen images from Atari 2600 games.

        Args:
            image: raw screen whose shape must equal the configured
                input shape.

        Returns:
            The frames currently held in the history, stacked with np.stack.

        Raises:
            ValueError: if ``image.shape`` differs from the input shape.
        """
        expected_shape = self._input_shape
        if image.shape != expected_shape:
            message = 'Expecting image in shape {0} but get {1}\n'.format(
                expected_shape, image.shape)
            raise ValueError(message)

        # Per-pixel maximum over the current screen and the previous one.
        merged = np.maximum(image, self.__previous_raw_image)
        rgb = Image.fromarray(merged, mode='RGB')

        # Keep only the first (luminance) band of the YCbCr representation.
        luminance = rgb.convert('YCbCr').split()[0]

        # Scale to 84 x 84.
        scaled = luminance.resize((84, 84), Image.BILINEAR)

        self.__processed_image_seq.append(np.array(scaled))
        self.__previous_raw_image = image

        frames = list(self.__processed_image_seq)
        return np.stack(frames)
class SlidingWindow(Preprocessing):
    """Stack windowed inputs (x(t-m+1), ... x(t)).

    The window holds the ``history_len`` most recent inputs; at the start of
    an episode it is filled with zero arrays.
    """

    def __init__(self, input_shape, history_len=4, dtype=np.float32):
        """Create the window and fill it with zeros.

        Args:
            input_shape: expected shape of every input.
            history_len: number of inputs kept in the window.
            dtype: expected dtype of every input, also used for the zero
                padding.
        """
        super(SlidingWindow, self).__init__(input_shape)
        self.__dtype = dtype
        self.__history_len = history_len
        # A bounded deque discards the oldest entry on each append.
        self.__history = deque(maxlen=history_len)
        self.reset()

    def output_shape(self):
        """Return shape of preprocessed input."""
        window_axis = (self.__history_len,)
        return window_axis + self._input_shape

    def reset(self):
        """Reset preprocessing pipeline for new episode."""
        window = self.__history
        window.clear()
        for _ in range(self.__history_len):
            padding = np.zeros(self._input_shape, self.__dtype)
            window.append(padding)

    def preprocess(self, x):
        """Return preprocessed input x.

        Args:
            x: new input; its shape and dtype must match the configured ones.

        Returns:
            The current window contents, including ``x``, stacked with
            np.stack.

        Raises:
            ValueError: if the shape or the dtype of ``x`` is not the
                expected one.
        """
        if x.shape != self._input_shape:
            shape_error = 'Expecting input in shape {0} but get {1}\n'.format(
                self._input_shape, x.shape)
            raise ValueError(shape_error)

        if x.dtype != self.__dtype:
            dtype_error = 'Expecting input in dtype {0} but get {1}\n'.format(
                self.__dtype, x.dtype)
            raise ValueError(dtype_error)

        self.__history.append(x)
        window = list(self.__history)
        return np.stack(window)