Skip to content

odak.learn.perception

odak.learn.perception

Defines a number of different perceptual loss functions, which can be used to optimise images where gaze location is known.

BlurLoss

BlurLoss implements two different blur losses. When blur_source is set to False, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.

When blur_source is set to True, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.

The interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/blur_loss.py
class BlurLoss():
    """ 
    Gaze-contingent blur loss with two operating modes.

    With `blur_source` set to `False` the loss is blur_match: the unmodified input image is compared against a
    blurred copy of the target.

    With `blur_source` set to `True` the loss is blur_lowpass: the input image is blurred in the same way as the
    target before the comparison, so only the low frequencies of both images are matched.

    The calling convention follows other `pytorch` loss functions, except that the gaze location has to be passed
    alongside the source and target images.
    """


    def __init__(self, device=torch.device("cpu"),
                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
        """
        Parameters
        ----------

        device                  : torch.device
                                    Device recorded on the instance (stored as `self.device`).
        alpha                   : float
                                    Foveation parameter; larger values give bigger pooling regions.
        real_image_width        : float 
                                    Physical width of the image as shown to the user.
                                    Any unit works as long as real_viewing_distance uses the same one.
        real_viewing_distance   : float 
                                    Physical distance between the observer's eyes and the image plane.
                                    Any unit works as long as real_image_width uses the same one.
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Determines how pooling regions grow
                                    with distance from the fovea. Best results were obtained with "quadratic".
        blur_source             : bool
                                    If true, the source image is blurred as well as the target before the loss is computed.
        equi                    : bool
                                    If true, run in equirectangular mode: the input is treated as an equirectangular
                                    360 image, real_image_width and real_viewing_distance are ignored, and gaze is
                                    interpreted as gaze angles in the range [-pi,pi]x[-pi/2,pi].
                                    NOTE(review): the second interval is probably meant to be [-pi/2,pi/2] - confirm.
        """
        # Plain configuration.
        self.device = device
        self.equi = equi
        self.mode = mode
        self.alpha = alpha
        self.blur_source = blur_source
        self.real_image_width = real_image_width
        self.real_viewing_distance = real_viewing_distance
        # Comparison applied between the (optionally blurred) image and the blurred target.
        self.loss_func = torch.nn.MSELoss()
        # The blur operator is created on first use in blur_image.
        self.blur = None
        self.target = None

    def blur_image(self, image, gaze):
        """
        Applies the radially varying blur to an image for the given gaze, using the settings stored on this instance.

        Parameters
        ----------
        image               : torch.tensor
                                Image to blur.
        gaze                : list
                                Gaze location forwarded to the blur operator.

        Returns
        -------

        blurred             : torch.tensor
                                Result returned by `RadiallyVaryingBlur.blur`.
        """
        blur_operator = self.blur
        if blur_operator is None:
            # Build the operator once and keep it for later calls.
            blur_operator = RadiallyVaryingBlur()
            self.blur = blur_operator
        return blur_operator.blur(
            image,
            self.alpha,
            self.real_image_width,
            self.real_viewing_distance,
            gaze,
            self.mode,
            self.equi,
        )

    def __call__(self, image, target, gaze=[0.5, 0.5]):
        """ 
        Calculates the Blur Loss.

        Parameters
        ----------
        image               : torch.tensor
                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target              : torch.tensor
                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        gaze                : list
                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("BlurLoss", image, target)
        # The target is always blurred; the image only in blur_lowpass mode.
        reference = self.blur_image(target, gaze)
        candidate = self.blur_image(image, gaze) if self.blur_source else image
        return self.loss_func(candidate, reference)

    def to(self, device):
        """
        Records the device on the instance and returns the instance itself so calls can be chained.

        Parameters
        ----------
        device              : torch.device
                                Device to store in `self.device`.
        """
        self.device = device
        return self

__call__(image, target, gaze=[0.5, 0.5])

Calculates the Blur Loss.

Parameters:

  • image
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • gaze
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/blur_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):
    """ 
    Calculates the Blur Loss.

    Parameters
    ----------
    image               : torch.tensor
                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target              : torch.tensor
                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    gaze                : list
                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("BlurLoss", image, target)
    # The target is always blurred; the image only in blur_lowpass mode.
    reference = self.blur_image(target, gaze)
    candidate = self.blur_image(image, gaze) if self.blur_source else image
    return self.loss_func(candidate, reference)

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', blur_source=False, equi=False)

Parameters:

  • alpha
                        parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
    
  • real_viewing_distance
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
    
  • mode
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • blur_source
                        If true, blurs the source image as well as the target before computing the loss.
    
  • equi
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing distance are ignored.
                        The gaze argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi]
    
Source code in odak/learn/perception/blur_loss.py
def __init__(self, device=torch.device("cpu"),
             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
    """
    Parameters
    ----------

    device                  : torch.device
                                Device recorded on the instance (stored as `self.device`).
    alpha                   : float
                                Foveation parameter; larger values give bigger pooling regions.
    real_image_width        : float 
                                Physical width of the image as shown to the user.
                                Any unit works as long as real_viewing_distance uses the same one.
    real_viewing_distance   : float 
                                Physical distance between the observer's eyes and the image plane.
                                Any unit works as long as real_image_width uses the same one.
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Determines how pooling regions grow
                                with distance from the fovea. Best results were obtained with "quadratic".
    blur_source             : bool
                                If true, the source image is blurred as well as the target before the loss is computed.
    equi                    : bool
                                If true, run in equirectangular mode: the input is treated as an equirectangular
                                360 image, real_image_width and real_viewing_distance are ignored, and gaze is
                                interpreted as gaze angles in the range [-pi,pi]x[-pi/2,pi].
                                NOTE(review): the second interval is probably meant to be [-pi/2,pi/2] - confirm.
    """
    # Plain configuration.
    self.device = device
    self.equi = equi
    self.mode = mode
    self.alpha = alpha
    self.blur_source = blur_source
    self.real_image_width = real_image_width
    self.real_viewing_distance = real_viewing_distance
    # Comparison applied between the (optionally blurred) image and the blurred target.
    self.loss_func = torch.nn.MSELoss()
    # The blur operator is created on first use in blur_image.
    self.blur = None
    self.target = None

display_color_hvs

Source code in odak/learn/perception/color_conversion.py
class display_color_hvs():
    """
    Color loss that compares two images after converting them from display primaries to cone (LMS)
    responses and then to an opponent ("third stage") representation.
    """

    def __init__(self, resolution = [1920, 1080],
                 distance_from_screen = 800,
                 pixel_pitch = 0.311,
                 read_spectrum = 'tensor',
                 primaries_spectrum = torch.rand(3, 301),
                 device = torch.device('cpu')):
        '''
        Parameters
        ----------
        resolution                  : list
                                      Resolution of the display in pixels.
        distance_from_screen        : int
                                      Distance from the screen in mm.
        pixel_pitch                 : float
                                      Pixel pitch of the display in mm.
        read_spectrum               : str
                                      Source of the primary spectra. Default is 'tensor', meaning the spectra
                                      are taken from `primaries_spectrum`. Any other value only changes the
                                      warning that is logged; `primaries_spectrum` is still used.
        primaries_spectrum          : torch.tensor
                                      Spectrum of every display primary, one row per primary. The cone curves
                                      built by this class have 301 samples (400 nm to 700 nm), so rows are
                                      expected to have 301 samples as well.
                                      NOTE: the default is a random tensor created once, when the class is
                                      defined, and shared by all instances that rely on it.
        device                      : torch.device
                                      Device to run the code on. Default is torch.device('cpu').

        '''
        self.device = device
        self.read_spectrum = read_spectrum
        self.primaries_spectrum = primaries_spectrum.to(self.device)
        self.resolution = resolution
        self.distance_from_screen = distance_from_screen
        self.pixel_pitch = pixel_pitch
        self.l_normalised, self.m_normalised, self.s_normalised = self.initialize_cones_normalised()
        # Both matrices hold the same cone responses; primaries_tensor is the transposed layout.
        self.lms_tensor = self.construct_matrix_lms(
                                                    self.l_normalised,
                                                    self.m_normalised,
                                                    self.s_normalised
                                                   )
        self.primaries_tensor = self.construct_matrix_primaries(
                                                    self.l_normalised,
                                                    self.m_normalised,
                                                    self.s_normalised
                                                   )


    def __call__(self, input_image, ground_truth, gaze=None):
        """
        Evaluating an input image against a target ground truth image for a given gaze of a viewer.

        Parameters
        ----------
        input_image                 : torch.tensor
                                      Image in display primaries space [BxPxHxW].
        ground_truth                : torch.tensor
                                      Target image in display primaries space [BxPxHxW].
        gaze                        : list
                                      Accepted for interface compatibility; currently not used.

        Returns
        -------
        loss_metamer_color          : torch.tensor
                                      Mean squared difference of the third stage representations.
        """
        lms_image_second = self.primaries_to_lms(input_image.to(self.device))
        lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))
        lms_image_third = self.second_to_third_stage(lms_image_second)
        lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)
        loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)
        return loss_metamer_color


    @staticmethod
    def _normalised_gaussian(wavelengths, centre, sigma):
        """
        Internal helper sampling a bell curve at the given wavelengths and scaling it so its maximum is one.

        NOTE(review): the exponent applies both -0.5 and 1 / (2 * sigma ** 2), so the curve is wider than a
        Gaussian with standard deviation `sigma`. This is kept as is to preserve existing numerical results.
        """
        profile = 1 / (sigma * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(
            -0.5 * (wavelengths - centre) ** 2 / (2 * sigma ** 2))
        return torch.from_numpy(profile / profile.max())


    def initialize_cones_normalised(self):
        """
        Internal function to initialize normalised L,M,S cones as normal distribution with given sigma, and mu values. 

        Returns
        -------
        l_cone_n                     : torch.tensor
                                       Normalised L cone distribution.
        m_cone_n                     : torch.tensor
                                       Normalised M cone distribution.
        s_cone_n                     : torch.tensor
                                       Normalised S cone distribution.
        """
        wavelength_range = np_cpu.linspace(400, 700, num=301)
        l_cone_n = self._normalised_gaussian(wavelength_range, 567.5, 32.5)
        m_cone_n = self._normalised_gaussian(wavelength_range, 545.0, 27.5)
        s_cone_n = self._normalised_gaussian(wavelength_range, 447.5, 17.0)
        return l_cone_n.to(self.device), m_cone_n.to(self.device), s_cone_n.to(self.device)


    def initialize_rgb_backlight_spectrum(self):
        """
        Internal function to initialize backlight spectrum for color primaries. 

        Returns
        -------
        red_spectrum                 : torch.tensor
                                       Normalised backlight spectrum for red color primary.
        green_spectrum               : torch.tensor
                                       Normalised backlight spectrum for green color primary.
        blue_spectrum                : torch.tensor
                                       Normalised backlight spectrum for blue color primary.
        """
        wavelength_range = np_cpu.linspace(400, 700, num=301)
        red_spectrum = self._normalised_gaussian(wavelength_range, 650, 14.5)
        green_spectrum = self._normalised_gaussian(wavelength_range, 550, 12.0)
        blue_spectrum = self._normalised_gaussian(wavelength_range, 450, 12.0)
        return red_spectrum.to(self.device), green_spectrum.to(self.device), blue_spectrum.to(self.device)


    def initialize_random_spectrum_normalised(self, dataset):
        """
        Initialize normalised light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. 

        Parameters
        ----------
        dataset                                : torch.tensor 
                                                 spectrum value against wavelength. Two-dimensional, with
                                                 wavelengths in one row/column and values in the other; the
                                                 longer axis is treated as the sample axis.

        Returns
        -------
        spectrum                               : torch.tensor
                                                 Fitted spectrum sampled at 301 wavelengths (400 nm to 700 nm).
        """
        if isinstance(dataset, torch.Tensor):
            # detach/cpu so tensors that require grad or live on an accelerator convert cleanly.
            dataset = dataset.detach().cpu().numpy()
        if dataset.shape[0] > dataset.shape[1]:
            dataset = np_cpu.swapaxes(dataset, 0, 1)
        x_spectrum = np_cpu.linspace(400, 700, num=301)
        y_spectrum = np_cpu.interp(x_spectrum, dataset[0], dataset[1])
        # Centre the wavelength axis around 550 nm for the fit.
        x_spectrum = torch.from_numpy(x_spectrum) - 550
        y_spectrum = torch.from_numpy(y_spectrum)
        y_spectrum = y_spectrum / torch.max(y_spectrum)

        def gaussian(x, A = 1, sigma = 1, centre = 0):
            return A * torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))

        def function(x, weights):
            # Sum of three Gaussians; weights holds (A, sigma, centre) for each of them.
            return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])

        weights = torch.tensor(
            [1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad=True)
        optimizer = torch.optim.LBFGS(
            [weights], max_iter = 1000, lr = 0.1, line_search_fn=None)

        def closure():
            optimizer.zero_grad()
            output = function(x_spectrum, weights)
            loss = F.mse_loss(output, y_spectrum)
            loss.backward()
            return loss

        optimizer.step(closure)
        spectrum = function(x_spectrum, weights)
        return spectrum.detach().to(self.device)


    @staticmethod
    def display_spectrum_response(wavelength, function):
        """
        Internal function to provide light spectrum response at particular wavelength

        Declared as a static method: the signature has no `self`, so without the decorator a call on an
        instance would bind the instance to `wavelength`.

        Parameters
        ----------
        wavelength                          : float
                                              Wavelength in nm [400...700]. Values outside the range are clamped.
        function                            : torch.tensor
                                              Display light spectrum distribution function, indexed so that
                                              entry 0 corresponds to 400 nm and entry 300 to 700 nm.

        Returns
        -------
        light_response                       : float
                                               Display light spectrum response value
        """
        index = int(round(wavelength, 0)) - 400
        # Clamp to the sampled range 400 nm .. 700 nm.
        index = min(max(index, 0), 300)
        return function[index].item()


    def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
        """
        Internal function to calculate cone response at particular light spectrum. 

        Parameters
        ----------
        cone_spectrum                         : torch.tensor
                                                Sampled cone sensitivity curve.
        light_spectrum                        : torch.tensor
                                                Sampled light spectrum, on the same sample grid as `cone_spectrum`.


        Returns
        -------
        response_to_spectrum                  : float
                                                Sum of the element-wise product of the two spectra.
        """
        return (cone_spectrum * light_spectrum).sum().item()


    def _log_spectrum_source(self):
        """
        Internal helper logging which source of primary spectra is in use.
        """
        if self.read_spectrum == 'tensor':
            logging.warning('Tensor primary spectrum is used')
            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
        else:
            logging.warning("No Spectrum data is provided")


    def _cone_responses(self, l_response, m_response, s_response):
        """
        Internal helper returning a [number of primaries x 3] tensor whose entry (i, j) is the response of
        cone j (L, M, S) to the spectrum of primary i.
        """
        cones = (l_response, m_response, s_response)
        responses = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)
        for row, primary in enumerate(self.primaries_spectrum):
            for column, cone in enumerate(cones):
                responses[row, column] = self.cone_response_to_spectrum(cone, primary)
        return responses


    def construct_matrix_lms(self, l_response, m_response, s_response):
        '''
        Internal function to build the matrix of cone responses to every display primary.

        Parameters
        ----------
        l_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)
        m_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)
        s_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)



        Returns
        -------
        lms_tensor                            : torch.tensor
                                                [number of primaries x 3] tensor, also stored as `self.lms_tensor`.

        '''
        self._log_spectrum_source()
        self.lms_tensor = self._cone_responses(l_response, m_response, s_response)
        return self.lms_tensor

    def construct_matrix_primaries(self, l_response, m_response, s_response):
        '''
        Internal function to build the transposed matrix of cone responses to every display primary.

        Parameters
        ----------
        l_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)
        m_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)
        s_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)



        Returns
        -------
        primaries_tensor                      : torch.tensor
                                                [3 x number of primaries] tensor, also stored as
                                                `self.primaries_tensor`.

        '''
        self._log_spectrum_source()
        responses = self._cone_responses(l_response, m_response, s_response)
        self.primaries_tensor = responses.t().contiguous()
        return self.primaries_tensor


    def primaries_to_lms(self, primaries):
        """
        Internal function to convert primaries space to LMS space 

        Parameters
        ----------
        primaries                              : torch.tensor
                                                 Primaries data to be transformed to LMS space [BxPxHxW]


        Returns
        -------
        lms_color                              : torch.tensor
                                                 LMS data transformed from Primaries space [Bx3xHxW], in double precision.
        """
        # Channels last so every pixel is a row vector that can be multiplied by the [P x 3] matrix.
        primaries = primaries.permute(0, 2, 3, 1).to(self.device)
        batch, height, width = primaries.size(0), primaries.size(1), primaries.size(2)
        primaries_flatten = torch.flatten(primaries, start_dim = 1, end_dim = 2)
        converted = torch.matmul(primaries_flatten.double(), self.lms_tensor.double())
        lms_color = converted.reshape(batch, height, width, converted.size(-1))
        return lms_color.permute(0, 3, 1, 2)


    def lms_to_primaries(self, lms_color_tensor):
        """
        Internal function to convert LMS image to primaries space

        Parameters
        ----------
        lms_color_tensor                        : torch.tensor
                                                  LMS data to be transformed to primaries space [Bx3xHxW]


        Returns
        -------
        primaries                              : torch.tensor
                                                 Primaries data transformed from LMS space [BxPxHxW], in double precision.
        """
        lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)
        # The pseudo-inverse of the [P x 3] matrix maps LMS rows back to primaries.
        inverse = self.lms_tensor.pinverse().double()
        primaries = torch.matmul(lms_color_tensor.double(), inverse)
        return primaries.permute(0, 3, 1, 2)


    def second_to_third_stage(self, lms_image):
        '''
        This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, (L+M+S)/3]
        Equations are taken from Schmidt et al "Neurobiological hypothesis of color appearance and hue perception" 2014

        Parameters
        ----------
        lms_image                             : torch.tensor
                                                 Image data at LMS space (second stage) [Bx3xHxW]

        Returns
        -------
        third_stage                            : torch.tensor
                                                 Image data at LMS space (third stage) [Bx3xHxW]. The result is
                                                 written into a buffer created with torch.zeros, so it has
                                                 torch's default dtype and lives on `self.device`.

        '''
        lms_image = lms_image.permute(0,2,3,1)
        long_cone = lms_image[:, :, :, 0]
        medium_cone = lms_image[:, :, :, 1]
        short_cone = lms_image[:, :, :, 2]
        third_stage = torch.zeros(lms_image.shape[0],
            lms_image.shape[1], lms_image.shape[2], 3).to(self.device)
        third_stage[:, :, :, 0] = (medium_cone + short_cone) - long_cone
        third_stage[:, :, :, 1] = (long_cone + short_cone) - medium_cone
        third_stage[:, :, :, 2] = torch.sum(lms_image, dim=3) / 3.
        return third_stage.permute(0, 3, 1, 2)


    def to(self, device):
        """
        Utilization function for setting the device.

        Besides recording the device, the tensors cached on the instance are moved to it so that later
        conversions do not mix devices.

        Parameters
        ----------
        device       : torch.device
                       Device to be used (e.g., CPU, Cuda, OpenCL).
        """
        self.device = device
        cached_tensors = (
            'primaries_spectrum',
            'l_normalised',
            'm_normalised',
            's_normalised',
            'lms_tensor',
            'primaries_tensor',
        )
        for name in cached_tensors:
            value = getattr(self, name, None)
            if torch.is_tensor(value):
                setattr(self, name, value.to(device))
        return self

__call__(input_image, ground_truth, gaze=None)

Evaluating an input image against a target ground truth image for a given gaze of a viewer.

Source code in odak/learn/perception/color_conversion.py
def __call__(self, input_image, ground_truth, gaze=None):
    """
    Evaluating an input image against a target ground truth image for a given gaze of a viewer.

    Both images are converted to LMS space, then to the third stage representation, and the mean
    squared difference of the two results is returned. The `gaze` argument is currently not used.
    """
    # Second stage: display primaries to LMS cone responses.
    image_lms = self.primaries_to_lms(input_image.to(self.device))
    target_lms = self.primaries_to_lms(ground_truth.to(self.device))
    # Third stage: opponent representation.
    image_opponent = self.second_to_third_stage(image_lms)
    target_opponent = self.second_to_third_stage(target_lms)
    difference = target_opponent - image_opponent
    return torch.mean(difference ** 2)

__init__(resolution=[1920, 1080], distance_from_screen=800, pixel_pitch=0.311, read_spectrum='tensor', primaries_spectrum=torch.rand(3, 301), device=torch.device('cpu'))

Parameters:

  • resolution
                          Resolution of the display in pixels.
    
  • distance_from_screen
                          Distance from the screen in mm.
    
  • pixel_pitch
                          Pixel pitch of the display in mm.
    
  • read_spectrum
                          Source of the display's primary spectra. Default is 'tensor', which means the spectra are taken from the primaries_spectrum tensor.
    
  • device
                          Device to run the code on. Default is torch.device('cpu').
    
Source code in odak/learn/perception/color_conversion.py
def __init__(self, resolution = [1920, 1080],
             distance_from_screen = 800,
             pixel_pitch = 0.311,
             read_spectrum = 'tensor',
             primaries_spectrum = torch.rand(3, 301),
             device = torch.device('cpu')):
    '''
    Parameters
    ----------
    resolution                  : list
                                  Resolution of the display in pixels.
    distance_from_screen        : int
                                  Distance from the screen in mm.
    pixel_pitch                 : float
                                  Pixel pitch of the display in mm.
    read_spectrum               : str
                                  Source of the primary spectra. Default is 'tensor', meaning the spectra
                                  are taken from `primaries_spectrum`.
    primaries_spectrum          : torch.tensor
                                  Spectrum of every display primary, one row per primary.
                                  NOTE: the default is a random tensor created once, when the function is
                                  defined, and shared by all instances that rely on it.
    device                      : torch.device
                                  Device to run the code on. Default is torch.device('cpu').

    '''
    # Display description, stored as given.
    self.resolution = resolution
    self.pixel_pitch = pixel_pitch
    self.distance_from_screen = distance_from_screen
    self.read_spectrum = read_spectrum
    # Spectra are moved to the requested device before any matrix is built from them.
    self.device = device
    self.primaries_spectrum = primaries_spectrum.to(self.device)
    cones = self.initialize_cones_normalised()
    self.l_normalised, self.m_normalised, self.s_normalised = cones
    # Cone responses to every primary, in [P x 3] and [3 x P] layouts.
    self.lms_tensor = self.construct_matrix_lms(
        self.l_normalised, self.m_normalised, self.s_normalised)
    self.primaries_tensor = self.construct_matrix_primaries(
        self.l_normalised, self.m_normalised, self.s_normalised)

cone_response_to_spectrum(cone_spectrum, light_spectrum)

Internal function to calculate cone response at particular light spectrum.

Parameters:

  • cone_spectrum
                                    Spectrum, Wavelength [2,300] tensor
    
  • light_spectrum
                                    Spectrum, Wavelength [2,300] tensor
    

Returns:

  • response_to_spectrum ( float ) –

    Response of cone to light spectrum [1x1]

Source code in odak/learn/perception/color_conversion.py
def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
    """
    Internal function to calculate the response of a cone to a light spectrum.

    The response is the sum over all elements of the element-wise product of the two spectra.

    Parameters
    ----------
    cone_spectrum                         : torch.tensor
                                            Cone sensitivity spectrum.
    light_spectrum                        : torch.tensor
                                            Light spectrum, multiplied element-wise with `cone_spectrum`.


    Returns
    -------
    response_to_spectrum                  : float
                                            Scalar response of the cone to the light spectrum.
    """
    return torch.mul(cone_spectrum, light_spectrum).sum().item()

construct_matrix_lms(l_response, m_response, s_response)

Internal function to construct the [P x 3] matrix holding the L, M and S cone responses to each of the P display primaries.

Parameters:

  • l_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    
  • m_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    
  • s_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    

Returns:

  • lms_image_tensor ( tensor ) –

    3x3 LMSrgb tensor

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_lms(self, l_response, m_response, s_response):
    """
    Internal function to build the matrix of L, M and S cone responses to every display primary.

    Parameters
    ----------
    l_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)
    m_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)
    s_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)

    Returns
    -------
    lms_tensor                             : torch.tensor
                                             [P x 3] tensor, where P is `self.primaries_spectrum.shape[0]`.
                                             Row i holds the L, M and S responses to primary i.
                                             The same tensor is stored in `self.lms_tensor`.
    """
    if self.read_spectrum != 'tensor':
        logging.warning("No Spectrum data is provided")
    else:
        logging.warning('Tensor primary spectrum is used')
        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))

    primaries = self.primaries_spectrum
    # The matrix is published on the instance before it is filled, entry by entry.
    self.lms_tensor = torch.zeros(primaries.shape[0], 3).to(self.device)
    cone_responses = (l_response, m_response, s_response)
    for row, primary_spectrum in enumerate(primaries):
        for column, cone_response in enumerate(cone_responses):
            self.lms_tensor[row, column] = self.cone_response_to_spectrum(
                cone_response,
                primary_spectrum
            )
    return self.lms_tensor

construct_matrix_primaries(l_response, m_response, s_response)

Internal function to construct the [3 x P] matrix holding the L, M and S cone responses to each of the P display primaries.

Parameters:

  • l_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    
  • m_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    
  • s_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    

Returns:

  • lms_image_tensor ( tensor ) –

    3x3 LMSrgb tensor

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_primaries(self, l_response, m_response, s_response):
    """
    Internal function to build the matrix of L, M and S cone responses to every display primary,
    laid out with one row per cone type.

    Parameters
    ----------
    l_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)
    m_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)
    s_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)

    Returns
    -------
    primaries_tensor                       : torch.tensor
                                             [3 x P] tensor, where P is `self.primaries_spectrum.shape[0]`.
                                             Column i holds the L, M and S responses to primary i.
                                             The same tensor is stored in `self.primaries_tensor`.
    """
    if self.read_spectrum != 'tensor':
        logging.warning("No Spectrum data is provided")
    else:
        logging.warning('Tensor primary spectrum is used')
        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))

    primaries = self.primaries_spectrum
    # The matrix is published on the instance before it is filled, entry by entry.
    self.primaries_tensor = torch.zeros(3, primaries.shape[0]).to(self.device)
    cone_responses = (l_response, m_response, s_response)
    for column, primary_spectrum in enumerate(primaries):
        for row, cone_response in enumerate(cone_responses):
            self.primaries_tensor[row, column] = self.cone_response_to_spectrum(
                cone_response,
                primary_spectrum
            )
    return self.primaries_tensor

display_spectrum_response(wavelength, function)

Internal function to provide light spectrum response at particular wavelength

Parameters:

  • wavelength
                                  Wavelength in nm [400...700]
    
  • function
                                  Display light spectrum distribution function
    

Returns:

  • light_response_dict ( float ) –

    Display light spectrum response value

Source code in odak/learn/perception/color_conversion.py
def display_spectrum_response(wavelength, function):
    """
    Internal function to look up the light spectrum response at a particular wavelength.

    The wavelength is rounded to the nearest integer and clamped to [400, 700] before
    it is used as an index, so out-of-range wavelengths return the first or last sample.

    Parameters
    ----------
    wavelength                          : float
                                          Wavelength in nm [400...700]
    function                            : torch.tensor
                                          Display light spectrum distribution function.
                                          Indexed with `wavelength - 400`, so index 300 must exist.

    Returns
    -------
    light_response_dict                 : float
                                          Display light spectrum response value
    """
    nearest_nm = int(round(wavelength, 0))
    clamped_nm = min(max(nearest_nm, 400), 700)
    return function[clamped_nm - 400].item()

initialize_cones_normalised()

Internal function to initialize normalised L,M,S cones as normal distribution with given sigma, and mu values.

Returns:

  • l_cone_n ( tensor ) –

    Normalised L cone distribution.

  • m_cone_n ( tensor ) –

    Normalised M cone distribution.

  • s_cone_n ( tensor ) –

    Normalised S cone distribution.

Source code in odak/learn/perception/color_conversion.py
def initialize_cones_normalised(self):
    """
    Internal function to initialize normalised L,M,S cones as Gaussian curves with given sigma, and mu values.

    Each curve is sampled at 301 wavelengths from 400 nm to 700 nm and divided by its maximum.

    Returns
    -------
    l_cone_n                     : torch.tensor
                                   Normalised L cone distribution.
    m_cone_n                     : torch.tensor
                                   Normalised M cone distribution.
    s_cone_n                     : torch.tensor
                                   Normalised S cone distribution.
    """
    wavelength_range = np_cpu.linspace(400, 700, num=301)

    def normalised_curve(mu, sigma):
        # NOTE: the exponent carries both the -0.5 factor and the 2 * sigma**2 denominator,
        # i.e. exp(-(x - mu)**2 / (4 * sigma**2)). Kept exactly as written so the curves do not change.
        curve = 1 / (sigma * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(
            -0.5 * (wavelength_range - mu) ** 2 / (2 * sigma ** 2)
        )
        return torch.from_numpy(curve / curve.max()).to(self.device)

    l_cone_n = normalised_curve(567.5, 32.5)
    m_cone_n = normalised_curve(545.0, 27.5)
    s_cone_n = normalised_curve(447.5, 17.0)
    return l_cone_n, m_cone_n, s_cone_n

initialize_random_spectrum_normalised(dataset)

Initialize normalised light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS].

Parameters:

  • dataset
                                     spectrum value against wavelength
    
Source code in odak/learn/perception/color_conversion.py
def initialize_random_spectrum_normalised(self, dataset):
    """
    Initialize normalised light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS].

    The dataset is resampled at 301 wavelengths from 400 nm to 700 nm, divided by its maximum,
    and fitted with the sum of three Gaussians on a wavelength axis shifted by 550 nm.

    Parameters
    ----------
    dataset                                : torch.tensor or numpy.ndarray
                                             Spectrum value against wavelength, as a two dimensional array.
                                             If the first dimension is longer than the second the array is transposed,
                                             so that row 0 holds wavelengths and row 1 holds spectrum values.

    Returns
    -------
    spectrum                               : torch.tensor
                                             Fitted spectrum sampled at the 301 wavelengths, detached and moved to `self.device`.
    """
    if torch.is_tensor(dataset):
        # `.numpy()` raises for tensors that require grad or live on an accelerator,
        # so detach and copy to host memory first.
        dataset = dataset.detach().cpu().numpy()
    if dataset.shape[0] > dataset.shape[1]:
        dataset = np_cpu.swapaxes(dataset, 0, 1)
    x_spectrum = np_cpu.linspace(400, 700, num=301)
    y_spectrum = np_cpu.interp(x_spectrum, dataset[0], dataset[1])
    # Centre the wavelength axis so the initial Gaussian centres (-0.2, 0.0, 0.2) sit around 550 nm.
    x_spectrum = torch.from_numpy(x_spectrum) - 550
    y_spectrum = torch.from_numpy(y_spectrum)
    max_spectrum = torch.max(y_spectrum)
    y_spectrum /= max_spectrum

    def gaussian(x, A = 1, sigma = 1, centre = 0):
        return A * torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))

    def function(x, weights):
        # weights holds (A, sigma, centre) for each of the three Gaussians.
        return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])

    weights = torch.tensor(
        [1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad=True)
    optimizer = torch.optim.LBFGS(
        [weights], max_iter = 1000, lr = 0.1, line_search_fn=None)

    def closure():
        optimizer.zero_grad()
        output = function(x_spectrum, weights)
        loss = F.mse_loss(output, y_spectrum)
        loss.backward()
        return loss

    optimizer.step(closure)
    spectrum = function(x_spectrum, weights)
    return spectrum.detach().to(self.device)

initialize_rgb_backlight_spectrum()

Internal function to initialize backlight spectrum for color primaries.

Returns:

  • red_spectrum ( tensor ) –

    Normalised backlight spectrum for red color primary.

  • green_spectrum ( tensor ) –

    Normalised backlight spectrum for green color primary.

  • blue_spectrum ( tensor ) –

    Normalised backlight spectrum for blue color primary.

Source code in odak/learn/perception/color_conversion.py
def initialize_rgb_backlight_spectrum(self):
    """
    Internal function to initialize backlight spectrum for color primaries.

    Each primary is a Gaussian curve sampled at 301 wavelengths from 400 nm to 700 nm
    and divided by its maximum. The curves are centred at 650 nm, 550 nm and 450 nm.

    Returns
    -------
    red_spectrum                 : torch.tensor
                                   Normalised backlight spectrum for red color primary.
    green_spectrum               : torch.tensor
                                   Normalised backlight spectrum for green color primary.
    blue_spectrum                : torch.tensor
                                   Normalised backlight spectrum for blue color primary.
    """
    wavelength_range = np_cpu.linspace(400, 700, num=301)

    def normalised_curve(mu, sigma):
        # NOTE: the exponent carries both the -0.5 factor and the 2 * sigma**2 denominator,
        # i.e. exp(-(x - mu)**2 / (4 * sigma**2)). Kept exactly as written so the curves do not change.
        curve = 1 / (sigma * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(
            -0.5 * (wavelength_range - mu) ** 2 / (2 * sigma ** 2)
        )
        return torch.from_numpy(curve / curve.max()).to(self.device)

    red_spectrum = normalised_curve(650, 14.5)
    green_spectrum = normalised_curve(550, 12.0)
    blue_spectrum = normalised_curve(450, 12.0)
    return red_spectrum, green_spectrum, blue_spectrum

lms_to_primaries(lms_color_tensor)

Internal function to convert LMS image to primaries space

Parameters:

  • lms_color_tensor
                                      LMS data to be transformed to primaries space [Bx3xHxW]
    

Returns:

  • primaries ( tensor ) –

    Primaries data transformed from LMS space [BxPxHxW]

Source code in odak/learn/perception/color_conversion.py
def lms_to_primaries(self, lms_color_tensor):
    """
    Internal function to convert LMS image to primaries space.

    Every pixel is multiplied by the pseudo-inverse of `self.lms_tensor`; the computation runs in double precision.

    Parameters
    ----------
    lms_color_tensor                        : torch.tensor
                                              LMS data to be transformed to primaries space [Bx3xHxW]


    Returns
    -------
    primaries                              : torch.tensor
                                             Primaries data transformed from LMS space [BxPxHxW]
    """
    # Channels go last so that matmul applies the conversion matrix to every pixel.
    channels_last = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)
    lms_to_primaries_matrix = self.lms_tensor.pinverse().double()
    primaries_channels_last = torch.matmul(channels_last.double(), lms_to_primaries_matrix)
    return primaries_channels_last.permute(0, 3, 1, 2)

primaries_to_lms(primaries)

Internal function to convert primaries space to LMS space

Parameters:

  • primaries
                                     Primaries data to be transformed to LMS space [BxPxHxW]
    

Returns:

  • lms_color ( tensor ) –

    LMS data transformed from Primaries space [Bx3xHxW]

Source code in odak/learn/perception/color_conversion.py
def primaries_to_lms(self, primaries):
    """
    Internal function to convert primaries space to LMS space.

    Every pixel is multiplied by `self.lms_tensor`; the computation runs in double precision.

    Parameters
    ----------
    primaries                              : torch.tensor
                                             Primaries data to be transformed to LMS space [BxPxHxW]


    Returns
    -------
    lms_color                              : torch.tensor
                                             LMS data transformed from Primaries space [Bx3xHxW]
    """
    # Channels go last so that matmul applies the conversion matrix to every pixel.
    channels_last = primaries.permute(0, 2, 3, 1).to(self.device)
    lms_channels_last = torch.matmul(channels_last.double(), self.lms_tensor.double())
    return lms_channels_last.permute(0, 3, 1, 2)

second_to_third_stage(lms_image)

This function turns second stage [L,M,S] values into third stage [(L+S)-M, M-(L+S), (M+S)-L] Equations are taken from Schmidt et al "Neurobiological hypothesis of color appearance and hue perception" 2014

Parameters:

  • lms_image
                                     Image data at LMS space (second stage)
    

Returns:

  • third_stage ( tensor ) –

    Image data at LMS space (third stage)

Source code in odak/learn/perception/color_conversion.py
def second_to_third_stage(self, lms_image):
    """
    This function turns second stage [L,M,S] values into third stage values.
    As implemented, the three output channels are [(M+S)-L, (L+S)-M, mean over the input channels].
    Equations are taken from Schmidt et al "Neurobiological hypothesis of color appearance and hue perception" 2014

    Parameters
    ----------
    lms_image                             : torch.tensor
                                             Image data at LMS space (second stage), channels in the second dimension.

    Returns
    -------
    third_stage                            : torch.tensor
                                             Image data at LMS space (third stage), channels in the second dimension.

    """
    channels_last = lms_image.permute(0, 2, 3, 1)
    long_cone = channels_last[..., 0]
    medium_cone = channels_last[..., 1]
    short_cone = channels_last[..., 2]
    batch, height, width = channels_last.shape[:3]
    opponent = torch.zeros(batch, height, width, 3).to(self.device)
    opponent[..., 0] = (medium_cone + short_cone) - long_cone
    opponent[..., 1] = (long_cone + short_cone) - medium_cone
    opponent[..., 2] = torch.sum(channels_last, dim=3) / 3.
    return opponent.permute(0, 3, 1, 2)

to(device)

Utility function for setting the device.

Parameters:

  • device
           Device to be used (e.g., CPU, Cuda, OpenCL).
    
Source code in odak/learn/perception/color_conversion.py
def to(self, device):
    """
    Utility function for setting the device.

    Besides recording `device` in `self.device`, every `torch.Tensor` stored as an
    attribute of the instance is moved to it. Other methods move their inputs to
    `self.device` before combining them with the stored tensors, so leaving the
    stored tensors behind would cause a device mismatch.

    Parameters
    ----------
    device       : torch.device
                   Device to be used (e.g., CPU, Cuda, OpenCL).

    Returns
    -------
    self         : object
                   The instance itself, so that calls can be chained.
    """
    self.device = device
    # Snapshot the items first: attributes are rebound while walking over them.
    for name, value in list(vars(self).items()):
        if torch.is_tensor(value):
            setattr(self, name, value.to(device))
    return self

color_map(input_image, target_image, model='Lab Stats')

Internal function to map the color of an image to another image. Reference: Color transfer between images, Reinhard et al., 2001.

Parameters:

  • input_image
                  Input image in RGB color space [3 x m x n].
    
  • target_image

Returns:

  • mapped_image ( Tensor ) –

    Input image with the color distribution of the target image [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def color_map(input_image, target_image, model = 'Lab Stats'):
    """
    Internal function to map the color of an image to another image.
    Reference: Color transfer between images, Reinhard et al., 2001.

    Both images are converted to Lab; each Lab channel of the input is shifted and
    scaled so that its mean and standard deviation match those of the target.

    Parameters
    ----------
    input_image         : torch.Tensor
                          Input image in RGB color space [3 x m x n].
    target_image        : torch.Tensor
                          Image whose color statistics are transferred to the input image.
                          It is passed through the same conversion as `input_image`.
    model               : str
                          Color mapping model. Only 'Lab Stats' is handled; any other value returns None.

    Returns
    -------
    mapped_image           : torch.Tensor
                             Input image with the color distribution of the target image [3 x m x n].
    """
    if model != 'Lab Stats':
        return None
    lab_input = srgb_to_lab(input_image)
    lab_target = srgb_to_lab(target_image)
    for channel in range(3):
        source = lab_input[channel, :, :]
        reference = lab_target[channel, :, :]
        source_mean = torch.mean(source)
        source_std = torch.std(source)
        reference_mean = torch.mean(reference)
        reference_std = torch.std(reference)
        # Normalise the channel with the input statistics, then re-scale it with the target statistics.
        lab_input[channel, :, :] = (source - source_mean) * (reference_std / source_std) + reference_mean
    return lab_to_srgb(lab_input.permute(1, 2, 0))

hsv_to_rgb(image)

Definition to convert HSV space to RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

  • image
              Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    

Returns:

  • image_rgb ( tensor ) –

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def hsv_to_rgb(image):
    """
    Definition to convert HSV space to RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

    Parameters
    ----------
    image           : torch.tensor
                      Input image in HSV color space [k x 3 x m x n] or [3 x m x n].
                      The hue channel is divided by 2 * pi before use, i.e. it is taken in radians.

    Returns
    -------
    image_rgb       : torch.tensor
                      Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if image.dim() == 3:
        image = image.unsqueeze(0)
    hue = image[..., 0, :, :] / (2 * math.pi)
    saturation = image[..., 1, :, :]
    value = image[..., 2, :, :]
    scaled_hue = hue * 6
    sector = torch.floor(scaled_hue) % 6
    fraction = (scaled_hue % 6) - sector
    unit = torch.tensor(1.0)
    p = value * (unit - saturation)
    q = value * (unit - fraction * saturation)
    t = value * (unit - (unit - fraction) * saturation)
    sector = sector.long()
    # Per-sector candidates: entries 0-5 are red, 6-11 green, 12-17 blue.
    candidates = torch.stack(
        (
            value, q, p, p, t, value,
            t, value, value, q, p, p,
            p, p, t, value, value, q,
        ),
        dim=-3,
    )
    selection = torch.stack([sector, sector + 6, sector + 12], dim=-3)
    return torch.gather(candidates, -3, selection)

lab_to_srgb(image)

Definition to convert LAB space to SRGB color space.

Parameters:

  • image
              Input image in LAB color space[3 x m x n]
    

Returns:

  • image_srgb ( tensor ) –

    Output image in SRGB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def lab_to_srgb(image):
    """
    Definition to convert LAB space to SRGB color space.

    Parameters
    ----------
    image           : torch.tensor
                      Input image in LAB color space [3 x m x n].
                      A tensor whose last dimension is three is treated as [m x n x 3] and permuted to channels-first.

    Returns
    -------
    image_srgb     : torch.tensor
                      Output image in SRGB color space [3 x m x n].
    """
    if image.shape[-1] == 3:
        input_color = image.permute(2, 0, 1)
    else:
        input_color = image
    # lab ---> xyz
    y = (input_color[0:1, :, :] + 16) / 116
    a = input_color[1:2, :, :] / 500
    b = input_color[2:3, :, :] / 200
    x = y + a
    z = y - b
    xyz = torch.cat((x, y, z), 0)
    delta = 6 / 29
    factor = 3 * delta * delta
    xyz = torch.where(xyz > delta, xyz ** 3, factor * (xyz - 4 / 29))
    # Constants follow the dtype and device of the data so that float64 and accelerator inputs work.
    reference_illuminant = torch.tensor(
        [[[0.950428545]], [[1.000000000]], [[1.088900371]]],
        dtype=xyz.dtype, device=xyz.device
    )
    xyz_color = xyz * reference_illuminant
    # xyz ---> linear rgb
    A = torch.tensor([[3.241003275, -1.537398934, -0.498615861],
                      [-0.969224334, 1.875930071, 0.041554224],
                      [0.055639423, -0.204011202, 1.057148933]],
                     dtype=xyz.dtype, device=xyz.device)
    # [3 x m x n] -> [n x 3 x m], so that matmul applies A to the channel axis of every column.
    xyz_color = xyz_color.permute(2, 0, 1)
    linear_rgb_color = torch.matmul(A, xyz_color)
    linear_rgb_color = linear_rgb_color.permute(1, 2, 0)
    # linear rgb ---> srgb
    limit = 0.0031308
    # The clamp leaves the selected values untouched but keeps the power (and its gradient)
    # finite for the values that end up on the linear segment.
    gamma_encoded = 1.055 * torch.pow(linear_rgb_color.clamp(min=limit), 1.0 / 2.4) - 0.055
    image_srgb = torch.where(linear_rgb_color > limit, gamma_encoded, 12.92 * linear_rgb_color)
    return image_srgb

linear_rgb_to_rgb(image, threshold=0.0031308)

Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

  • image
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    
  • threshold
              Threshold used in calculations.
    

Returns:

  • image_linear ( tensor ) –

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_rgb(image, threshold = 0.0031308):
    """
    Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

    Parameters
    ----------
    image           : torch.tensor
                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    threshold       : float
                      Values above the threshold are gamma encoded; the remaining values are scaled linearly.

    Returns
    -------
    image_linear    : torch.tensor
                      Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if image.dim() == 3:
        image = image.unsqueeze(0)
    above_threshold = image > threshold
    # Clamping keeps the power well defined for the values that take the linear segment anyway.
    gamma_encoded = 1.055 * torch.pow(image.clamp(min=threshold), 1 / 2.4) - 0.055
    linear_segment = 12.92 * image
    return torch.where(above_threshold, gamma_encoded, linear_segment)

linear_rgb_to_xyz(image)

Definition to convert RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

  • image
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    

Returns:

  • image_xyz ( tensor ) –

    Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_xyz(image):
    """
    Definition to convert RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

    Parameters
    ----------
    image           : torch.tensor
                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

    Returns
    -------
    image_xyz       : torch.tensor
                      Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    # The matrix follows the dtype and device of the image so that float64 and accelerator inputs work.
    matrix_dtype = image.dtype if image.is_floating_point() else torch.get_default_dtype()
    M = torch.tensor([[0.412453, 0.357580, 0.180423],
                      [0.212671, 0.715160, 0.072169],
                      [0.019334, 0.119193, 0.950227]],
                     dtype=matrix_dtype, device=image.device)
    size = image.size()
    # Flatten the spatial dimensions so that matmul applies M to the channel axis of every image.
    image = image.reshape(size[0], size[1], size[2] * size[3])  # NC(HW)
    image_xyz = torch.matmul(M, image)
    image_xyz = image_xyz.reshape(size[0], size[1], size[2], size[3])
    return image_xyz

rgb_2_ycrcb(image)

Converts an image from RGB colourspace to YCrCb colourspace.

Parameters:

  • image
      Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
    

Returns:

  • ycrcb ( tensor ) –

    Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_2_ycrcb(image):
    """
    Converts an image from RGB colourspace to YCrCb colourspace.

    Parameters
    ----------
    image   : torch.tensor
              Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].

    Returns
    -------

    ycrcb   : torch.tensor
              Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if image.dim() == 3:
        image = image.unsqueeze(0)
    red = image[:, 0, :, :]
    green = image[:, 1, :, :]
    blue = image[:, 2, :, :]
    ycrcb = torch.zeros(image.size()).to(image.device)
    ycrcb[:, 0, :, :] = 0.299 * red + 0.587 * green + 0.114 * blue
    # The chroma channels are offsets from the luma value as stored in the output tensor.
    luma = ycrcb[:, 0, :, :]
    ycrcb[:, 1, :, :] = 0.5 + 0.713 * (red - luma)
    ycrcb[:, 2, :, :] = 0.5 + 0.564 * (blue - luma)
    return ycrcb

rgb_to_hsv(image, eps=1e-08)

Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

  • image
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    

Returns:

  • image_hsv ( tensor ) –

    Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_hsv(image, eps: float = 1e-8):
    """
    Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

    Parameters
    ----------
    image           : torch.tensor
                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    eps             : float
                      Small value added to the denominator of the saturation to avoid division by zero.

    Returns
    -------
    image_hsv       : torch.tensor
                      Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].
                      The hue channel is scaled by 2 * pi, i.e. it is expressed in radians.
    """
    if image.dim() == 3:
        image = image.unsqueeze(0)
    max_rgb, argmax_rgb = image.max(-3)
    min_rgb = image.min(-3).values
    chroma = max_rgb - min_rgb
    value = max_rgb
    saturation = chroma / (max_rgb + eps)
    # Colourless pixels have zero chroma; substitute one so the hue division stays finite.
    safe_chroma = torch.where(chroma == 0, torch.ones_like(chroma), chroma)
    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)
    hue_candidates = torch.stack(
        (
            bc - gc,
            (rc - bc) + 2.0 * safe_chroma,
            (gc - rc) + 4.0 * safe_chroma,
        ),
        dim=-3,
    ) / safe_chroma.unsqueeze(-3)
    # Pick the candidate that belongs to the channel holding the maximum.
    hue = torch.gather(hue_candidates, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
    hue = (hue / 6.0) % 1.0
    hue = 2.0 * math.pi * hue
    return torch.stack((hue, saturation, value), dim=-3)

rgb_to_linear_rgb(image, threshold=0.0031308)

Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

  • image
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    
  • threshold
              Threshold used in calculations.
    

Returns:

  • image_linear ( tensor ) –

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_linear_rgb(image, threshold = 0.0031308):
    """
    Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

    Parameters
    ----------
    image           : torch.tensor
                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    threshold       : float
                      Currently unused: the switch between the linear and the power segment is fixed at 0.04045.

    Returns
    -------
    image_linear    : torch.tensor
                      Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if image.dim() == 3:
        image = image.unsqueeze(0)
    power_segment = torch.pow(((image + 0.055) / 1.055), 2.4)
    linear_segment = image / 12.92
    return torch.where(image > 0.04045, power_segment, linear_segment)

srgb_to_lab(image)

Definition to convert SRGB space to LAB color space.

Parameters:

  • image
              Input image in SRGB color space[3 x m x n]
    

Returns:

  • image_lab ( tensor ) –

    Output image in LAB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def srgb_to_lab(image):
    """
    Definition to convert SRGB space to LAB color space.

    Parameters
    ----------
    image           : torch.tensor
                      Input image in SRGB color space [3 x m x n]. An image whose last dimension
                      has size three is treated as channel-last [m x n x 3] and is permuted to
                      channel-first before conversion.

    Returns
    -------
    image_lab       : torch.tensor
                      Output image in LAB color space [3 x m x n].
    """
    if image.shape[-1] == 3:
        input_color = image.permute(2, 0, 1)
    else:
        input_color = image

    # srgb ---> linear rgb
    limit = 0.04045
    linrgb_color = torch.where(input_color > limit, torch.pow((input_color + 0.055) / 1.055, 2.4), input_color / 12.92)

    # linear rgb ---> xyz
    a11 = 10135552 / 24577794
    a12 = 8788810  / 24577794
    a13 = 4435075  / 24577794
    a21 = 2613072  / 12288897
    a22 = 8788810  / 12288897
    a23 = 887015   / 12288897
    a31 = 1425312  / 73733382
    a32 = 8788810  / 73733382
    a33 = 70074185 / 73733382

    # Constants follow the dtype and device of the data so that float64 or
    # accelerator-resident images do not fail inside matmul / multiplication.
    A = torch.tensor([[a11, a12, a13],
                      [a21, a22, a23],
                      [a31, a32, a33]], dtype=linrgb_color.dtype, device=linrgb_color.device)

    # Flatten the spatial dimensions, mix the three channels, restore the layout.
    channels, height, width = linrgb_color.shape
    xyz_color = torch.matmul(A, linrgb_color.reshape(channels, height * width))
    xyz_color = xyz_color.reshape(channels, height, width)

    # xyz ---> lab
    inv_reference_illuminant = torch.tensor(
        [[[1.052156925]], [[1.000000000]], [[0.918357670]]],
        dtype=xyz_color.dtype, device=xyz_color.device)
    input_color = xyz_color * inv_reference_illuminant
    delta = 6 / 29
    delta_square = delta * delta
    delta_cube = delta * delta_square
    factor = 1 / (3 * delta_square)

    # Cube root above delta^3, linear segment below it.
    input_color = torch.where(input_color > delta_cube, torch.pow(input_color, 1 / 3), (factor * input_color + 4 / 29))

    f_x = input_color[0:1, :, :]
    f_y = input_color[1:2, :, :]
    f_z = input_color[2:3, :, :]
    lightness = 116 * f_y - 16
    a = 500 * (f_x - f_y)
    b = 200 * (f_y - f_z)

    image_lab = torch.cat((lightness, a, b), 0)
    return image_lab

xyz_to_linear_rgb(image)

Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

  • image
               Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    

Returns:

  • image_linear_rgb ( tensor ) –

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def xyz_to_linear_rgb(image):
    """
    Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

    Parameters
    ----------
    image            : torch.tensor
                       Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

    Returns
    -------
    image_linear_rgb : torch.tensor
                       Output image in linear RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    a11 = 3.240479
    a12 = -1.537150
    a13 = -0.498535
    a21 = -0.969256
    a22 = 1.875992
    a23 = 0.041556
    a31 = 0.055648
    a32 = -0.204043
    a33 = 1.057311
    # The matrix follows the dtype and device of the input so that float64 or
    # accelerator-resident images do not fail inside matmul. Non-floating inputs
    # keep the default floating dtype so the coefficients are never truncated.
    matrix_dtype = image.dtype if image.is_floating_point() else torch.get_default_dtype()
    M = torch.tensor([[a11, a12, a13],
                      [a21, a22, a23],
                      [a31, a32, a33]], dtype=matrix_dtype, device=image.device)
    size = image.size()
    # Flatten the spatial dimensions so the 3 x 3 matrix mixes the channels of every pixel.
    image = image.reshape(size[0], size[1], size[2]*size[3])
    image_linear_rgb = torch.matmul(M, image)
    image_linear_rgb = image_linear_rgb.reshape(size[0], size[1], size[2], size[3])
    return image_linear_rgb

ycrcb_2_rgb(image)

Converts an image from YCrCb colourspace to RGB colourspace.

Parameters:

  • image
      Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
    

Returns:

  • rgb ( tensor ) –

    Image converted to RGB colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def ycrcb_2_rgb(image):
    """
    Converts an image from YCrCb colourspace to RGB colourspace.

    Parameters
    ----------
    image   : torch.tensor
              Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].

    Returns
    -------
    rgb     : torch.tensor
              Image converted to RGB colourspace [k x 3 x m x n] or [1 x 3 x m x n].
              The result is allocated on the device of the input using the default dtype.
    """
    # A single image is promoted to a batch of one.
    if image.dim() == 3:
        image = image[None, ...]
    luma = image[:, 0, :, :]
    # Chroma channels are stored around 0.5; shift them so they are centred on zero.
    cr_centred = image[:, 1, :, :] - 0.5
    cb_centred = image[:, 2, :, :] - 0.5
    rgb = torch.zeros(image.size(), device=image.device)
    rgb[:, 0, :, :] = luma + 1.403 * cr_centred
    rgb[:, 1, :, :] = luma - 0.714 * cr_centred - 0.344 * cb_centred
    rgb[:, 2, :, :] = luma + 1.773 * cb_centred
    return rgb

make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)

Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction.

Parameters:

  • image_pixel_size
                        The size of the image in pixels, as a tuple of form (height, width)
    
  • real_image_width
                        The real width of the image as displayed. Units not important, as long as they
                        are the same as those used for real_viewing_distance
    
  • real_viewing_distance
                        The real distance from the user's viewpoint to the screen.
    

Returns:

  • map ( tensor ) –

    The computed 3D location map, of size 3xHxW (channels are x, y and z; H and W are the image height and width in pixels).

Source code in odak/learn/perception/foveation.py
def make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):
    """ 
    Builds, for every pixel of an image shown on a flat screen, the real 3D position of that pixel
    relative to the viewer. Assumes the viewpoint is located at the centre of the image, and the
    screen is perpendicular to the viewing direction.

    Parameters
    ----------

    image_pixel_size        : tuple of ints 
                                The size of the image in pixels, as a tuple of form (height, width)
    real_image_width        : float
                                The real width of the image as displayed. Units not important, as long as they
                                are the same as those used for real_viewing_distance
    real_viewing_distance   : float 
                                The real distance from the user's viewpoint to the screen.

    Returns
    -------

    map                     : torch.tensor
                                The computed 3D location map, of size 3xHxW. The three channels hold
                                the x, y and z coordinates respectively.
    """
    height_px = image_pixel_size[-2]
    width_px = image_pixel_size[-1]
    # The real height follows from the real width and the pixel aspect ratio.
    real_image_height = (real_image_width / width_px) * height_px

    # Coordinates along each screen axis, centred on the middle of the image.
    horizontal = torch.linspace(-0.5, 0.5, width_px) * real_image_width
    vertical = torch.linspace(-0.5, 0.5, height_px) * real_image_height

    x_plane = horizontal.reshape(1, 1, width_px).expand(1, height_px, width_px)
    y_plane = vertical.reshape(1, height_px, 1).expand(1, height_px, width_px)
    # Every pixel of the screen lies at the same depth.
    z_plane = torch.ones((1, height_px, width_px)) * real_viewing_distance

    return torch.cat((x_plane, y_plane, z_plane), dim=0)

make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)

Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output in radians.

Parameters:

  • gaze_location
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                        image coordinates (ranging from 0 to 1)
    
  • image_pixel_size
                        The size of the image in pixels, as a tuple of form (height, width)
    
  • real_image_width
                        The real width of the image as displayed. Units not important, as long as they
                        are the same as those used for real_viewing_distance
    
  • real_viewing_distance
                        The real distance from the user's viewpoint to the screen.
    

Returns:

  • eccentricity_map ( tensor ) –

    The computed eccentricity map, of size HxW.

  • distance_map ( tensor ) –

    The computed distance map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):
    """ 
    Computes, for a given fixation point, the eccentricity of every pixel of an image shown on a flat
    screen, together with the distance from the viewpoint to every pixel. Assumes the viewpoint is
    located at the centre of the image, and the screen is perpendicular to the viewing direction.
    Eccentricity is given in radians.

    Parameters
    ----------

    gaze_location           : tuple of floats
                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                                image coordinates (ranging from 0 to 1)
    image_pixel_size        : tuple of ints
                                The size of the image in pixels, as a tuple of form (height, width)
    real_image_width        : float
                                The real width of the image as displayed. Units not important, as long as they
                                are the same as those used for real_viewing_distance
    real_viewing_distance   : float
                                The real distance from the user's viewpoint to the screen.

    Returns
    -------

    eccentricity_map        : torch.tensor
                                The computed eccentricity map, of size HxW.
    distance_map            : torch.tensor
                                The computed distance map, of size HxW.
    """
    height_px = image_pixel_size[-2]
    width_px = image_pixel_size[-1]
    real_image_height = (real_image_width / width_px) * height_px

    # Position of every pixel, its distance from the viewpoint, and the unit vector towards it.
    pixel_positions = make_3d_location_map(
        image_pixel_size, real_image_width, real_viewing_distance)
    distance_map = (pixel_positions * pixel_positions).sum(dim=0).sqrt()
    pixel_directions = pixel_positions / distance_map

    # Map the normalised gaze coordinates onto the screen, then normalise to a unit vector.
    gaze_x = (gaze_location[0]*2 - 1)*real_image_width*0.5
    gaze_y = (gaze_location[1]*2 - 1)*real_image_height*0.5
    gaze_point = torch.tensor([gaze_x, gaze_y, real_viewing_distance])
    gaze_dir = gaze_point / (gaze_point * gaze_point).sum().sqrt()

    # The angle between the two unit vectors is the eccentricity. The cosine is
    # clamped so rounding error cannot push it outside the domain of acos.
    cosine_map = (gaze_dir.reshape(3, 1, 1) * pixel_directions).sum(dim=0)
    cosine_map = cosine_map.clamp(min=-1.0, max=1.0)
    eccentricity_map = torch.acos(cosine_map)

    return eccentricity_map, distance_map

make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')

This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from to achieve the correct pooling region areas.

Parameters:

  • gaze_angles
                    Gaze direction expressed as angles, in radians.
    
  • image_pixel_size
                    Dimensions of the image in pixels, as a tuple of (height, width)
    
  • alpha
                    Parameter controlling extent of foveation
    
  • mode
                    Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
    

Returns:

  • pooling_size_map ( tensor ) –

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode="quadratic"):
    """ 
    This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from
    to achieve the correct pooling region areas.

    Parameters
    ----------

    gaze_angles         : tuple of 2 floats
                            Gaze direction expressed as angles, in radians.
    image_pixel_size    : tuple of 2 ints
                            Dimensions of the image in pixels, as a tuple of (height, width)
    alpha               : float
                            Parameter controlling extent of foveation
    mode                : str
                            Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"

    Returns
    -------

    pooling_size_map        : torch.tensor
                                The computed pooling size map, of size HxW. Values are the base-2
                                logarithm of the pooling size in pixels, with negative levels set to zero.
    """
    pooling_pixel = make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha, mode)
    # The small offset keeps log2 finite where the pooling size is exactly zero.
    pooling_lod = torch.log2(1e-6+pooling_pixel)
    # Pooling regions smaller than one pixel map to the finest level.
    pooling_lod = torch.clamp(pooling_lod, min=0)
    return pooling_lod

make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')

This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images. Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image corresponds to pitch, ranging from -pi/2 to pi/2.

In this setup real_image_width and real_viewing_distance have no effect.

Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).

Parameters:

  • gaze_angles
                    Gaze direction expressed as angles, in radians.
    
  • image_pixel_size
                    Dimensions of the image in pixels, as a tuple of (height, width)
    
  • alpha
                    Parameter controlling extent of foveation
    
  • mode
                    Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
    
Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode="quadratic"):
    """
    This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images.
    Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, 
    the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image
    corresponds to pitch, ranging from -pi/2 to pi/2.

    In this setup real_image_width and real_viewing_distance have no effect.

    Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).

    Parameters
    ----------

    gaze_angles         : tuple of 2 floats
                            Gaze direction expressed as angles, in radians.
    image_pixel_size    : tuple of 2 ints
                            Dimensions of the image in pixels, as a tuple of (height, width)
    alpha               : float
                            Parameter controlling extent of foveation
    mode                : str
                            Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear".
                            Any value other than "quadratic" results in the linear behaviour.

    Returns
    -------

    size                : torch.tensor
                            The computed pooling size map in pixels, of size HxW.
    """
    gaze_yaw, gaze_pitch = gaze_angles[0], gaze_angles[1]
    view_direction = torch.tensor([
        math.sin(gaze_yaw)*math.cos(gaze_pitch),
        math.sin(gaze_pitch),
        math.cos(gaze_yaw)*math.cos(gaze_pitch)])

    yaw_angle_map = torch.linspace(-torch.pi, torch.pi, image_pixel_size[1])
    yaw_angle_map = yaw_angle_map[None,:].repeat(image_pixel_size[0], 1)[None,...]
    pitch_angle_map = torch.linspace(-torch.pi*0.5, torch.pi*0.5, image_pixel_size[0])
    pitch_angle_map = pitch_angle_map[:,None].repeat(1, image_pixel_size[1])[None,...]

    # Unit direction vector associated with every pixel of the equirectangular image.
    dir_map = torch.cat([
        torch.sin(yaw_angle_map)*torch.cos(pitch_angle_map),
        torch.sin(pitch_angle_map),
        torch.cos(yaw_angle_map)*torch.cos(pitch_angle_map)])

    # Work out the pooling region diameter in radians
    view_dot_dir = torch.sum(view_direction[:,None,None] * dir_map, dim=0)
    # Rounding error can push the dot product of two unit vectors slightly outside
    # [-1, 1], which would make acos return NaN (typically right at the gaze point).
    view_dot_dir = torch.clamp(view_dot_dir, min=-1.0, max=1.0)
    eccentricity = torch.acos(view_dot_dir)
    pooling_rad = alpha * eccentricity
    if mode == "quadratic":
        pooling_rad = pooling_rad * eccentricity

    # The actual pooling region will be an ellipse in the equirectangular image - the length of the major & minor axes
    # depend on the x & y resolution of the image. We find these two axis lengths (in pixels) and then the area of the ellipse
    pixels_per_rad_x = image_pixel_size[1] / (2*torch.pi)
    pixels_per_rad_y = image_pixel_size[0] / (torch.pi)
    pooling_axis_x = pooling_rad * pixels_per_rad_x
    pooling_axis_y = pooling_rad * pixels_per_rad_y
    area = torch.pi * pooling_axis_x * pooling_axis_y * 0.25

    # Now finally find the length of the side of a square of the same area.
    size = torch.sqrt(torch.abs(area))
    return size

make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')

This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from to achieve the correct pooling region areas.

Parameters:

  • gaze_location
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                        image coordinates (ranging from 0 to 1)
    
  • image_pixel_size
                        The size of the image in pixels, as a tuple of form (height, width)
    
  • real_image_width
                        The real width of the image as displayed. Units not important, as long as they
                        are the same as those used for real_viewing_distance
    
  • real_viewing_distance
                        The real distance from the user's viewpoint to the screen.
    

Returns:

  • pooling_size_map ( tensor ) –

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode="quadratic"):
    """ 
    Counterpart of make_pooling_size_map_pixels that expresses the pooling size as a LOD level
    (base-2 logarithm of the size in pixels) to sample from to achieve the correct pooling region areas.

    Parameters
    ----------

    gaze_location           : tuple of floats
                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                                image coordinates (ranging from 0 to 1)
    image_pixel_size        : tuple of ints
                                The size of the image in pixels, as a tuple of form (height, width)
    alpha                   : float
                                Parameter controlling extent of foveation
    real_image_width        : float
                                The real width of the image as displayed. Units not important, as long as they
                                are the same as those used for real_viewing_distance
    real_viewing_distance   : float
                                The real distance from the user's viewpoint to the screen.
    mode                    : str
                                Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"

    Returns
    -------

    pooling_size_map        : torch.tensor
                                The computed pooling size map, of size HxW, with negative levels set to zero.
    """
    size_in_pixels = make_pooling_size_map_pixels(
        gaze_location,
        image_pixel_size,
        alpha,
        real_image_width,
        real_viewing_distance,
        mode)
    # The small offset keeps log2 finite where the pooling size is exactly zero;
    # levels below zero (regions smaller than one pixel) are raised to zero.
    return torch.clamp(torch.log2(1e-6+size_in_pixels), min=0)

make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')

Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity (also in radians).

Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output is the width of the pooling region in pixels.

Parameters:

  • gaze_location
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                        image coordinates (ranging from 0 to 1)
    
  • image_pixel_size
                        The size of the image in pixels, as a tuple of form (height, width)
    
  • real_image_width
                        The real width of the image as displayed. Units not important, as long as they
                        are the same as those used for real_viewing_distance
    
  • real_viewing_distance
                        The real distance from the user's viewpoint to the screen.
    

Returns:

  • pooling_size_map ( tensor ) –

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode="quadratic"):
    """ 
    Computes the pooling size associated with each pixel of an image for a given fixation point, when the
    image is displayed to a user on a flat screen. The pooling size (in radians) is alpha times the
    eccentricity (also in radians), or alpha times the squared eccentricity in "quadratic" mode.

    Assumes the viewpoint is located at the centre of the image, and the screen is 
    perpendicular to the viewing direction. Output is the width of the pooling region in pixels.

    Parameters
    ----------

    gaze_location           : tuple of floats
                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                                image coordinates (ranging from 0 to 1)
    image_pixel_size        : tuple of ints
                                The size of the image in pixels, as a tuple of form (height, width)
    alpha                   : float
                                Parameter controlling extent of foveation
    real_image_width        : float
                                The real width of the image as displayed. Units not important, as long as they
                                are the same as those used for real_viewing_distance
    real_viewing_distance   : float
                                The real distance from the user's viewpoint to the screen.
    mode                    : str
                                Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"

    Returns
    -------

    pooling_size_map        : torch.tensor
                                The computed pooling size map, of size HxW.
    """
    # Eccentricity relative to the gaze, and relative to the centre of the screen.
    gaze_eccentricity, pixel_distance = make_eccentricity_distance_maps(
        gaze_location, image_pixel_size, real_image_width, real_viewing_distance)
    centre_eccentricity, _ = make_eccentricity_distance_maps(
        [0.5, 0.5], image_pixel_size, real_image_width, real_viewing_distance)

    # Angular diameter of the pooling region.
    pooling_rad = alpha * gaze_eccentricity
    if mode == "quadratic":
        pooling_rad = pooling_rad * gaze_eccentricity
    half_angle = pooling_rad*0.5

    # Project the angular region onto the screen: it becomes an ellipse whose two axes are
    # measured in the same real units as the screen.
    inner_angle = centre_eccentricity - half_angle
    outer_angle = centre_eccentricity + half_angle
    major_axis = (torch.tan(outer_angle) - torch.tan(inner_angle)) * real_viewing_distance
    minor_axis = 2 * pixel_distance * torch.tan(half_angle)

    # Side of the square with the same area as the ellipse. The absolute value guards the
    # square root against a negative area.
    ellipse_area = torch.abs(math.pi * major_axis * minor_axis * 0.25)
    side_real = torch.sqrt(ellipse_area)

    # Convert from real units to pixels using the horizontal resolution.
    return (side_real / real_image_width) * image_pixel_size[1]

make_radial_map(size, gaze)

Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.

Parameters:

  • size
        Dimensions of the image
    
  • gaze
        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
        image coordinates (ranging from 0 to 1)
    
Source code in odak/learn/perception/foveation.py
def make_radial_map(size, gaze):
    """ 
    Builds a radial map in which every pixel holds its distance from the chosen gaze location,
    divided by the largest distance found in the map.

    Parameters
    ----------

    size    : tuple of ints
                Dimensions of the image
    gaze    : tuple of floats
                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                image coordinates (ranging from 0 to 1)

    Returns
    -------

    radii   : torch.tensor
                Radial map of shape [size[0] x size[1]], scaled so that its maximum is one.
    """
    n_rows, n_cols = size[0], size[1]
    # Gaze position expressed in pixel units along each axis.
    gaze_row = gaze[0]*n_rows
    gaze_col = gaze[1]*n_cols
    # Offsets along each axis; broadcasting combines them into the full 2D grid.
    row_offsets = torch.linspace(0, n_rows, n_rows).reshape(n_rows, 1) - gaze_row
    col_offsets = torch.linspace(0, n_cols, n_cols).reshape(1, n_cols) - gaze_col
    radii = torch.sqrt(torch.pow(row_offsets, 2) + torch.pow(col_offsets, 2))
    return radii/radii.max()

MetamericLoss

The MetamericLoss class provides a perceptual loss function.

Rather than exactly match the source image to the target, it tries to ensure the source is a metamer to the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/metameric_loss.py
class MetamericLoss():
    """
    The `MetamericLoss` class provides a perceptual loss function.

    Rather than exactly match the source image to the target, it tries to ensure the source is a *metamer* to the target image.

    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.

    The statistics maps of the most recent target are cached on the instance. They are recomputed whenever the
    target (values, shape or device) or the gaze location changes.
    """

    def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
                 real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
                 n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
                 use_fullres_l0=False, equi=False):
        """
        Parameters
        ----------

        device                  : torch.device
                                    Device that the steerable pyramid is built on.
        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
        n_pyramid_levels        : int 
                                    Number of levels of the steerable pyramid. Note that the image is padded
                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                    too high will slow down the calculation a lot.
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        n_orientations          : int 
                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                    Increasing this will increase runtime.
        use_l2_foveal_loss      : bool 
                                    If true, for all the pixels that have pooling size 1 pixel in the 
                                    largest scale will use direct L2 against target rather than pooling over pyramid levels.
                                    In practice this gives better results when the loss is used for holography.
        fovea_weight            : float 
                                    A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
        use_radial_weight       : bool 
                                    If True, will apply a radial weighting when calculating the difference between
                                    the source and target stats maps. This weights stats closer to the fovea more than those
                                    further away.
        use_fullres_l0          : bool 
                                    If true, stats for the lowpass residual are replaced with blurred versions
                                    of the full-resolution source and target images.
        equi                    : bool
                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi]
                                    NOTE(review): the upper bound of the second range may be meant to be pi/2 - confirm.

        Raises
        ------

        Exception
                                    If use_fullres_l0 and use_l2_foveal_loss are both enabled.
        """
        # Cache of the last target (after padding / colour conversion), its statistics maps
        # and the gaze those statistics were computed for.
        self.target = None
        self.target_stats = None
        self.target_gaze = None
        self.device = device
        # The pyramid and the per-level blurs are created lazily in calc_statsmaps.
        self.pyramid_maker = None
        self.blurs = None
        self.alpha = alpha
        self.real_image_width = real_image_width
        self.real_viewing_distance = real_viewing_distance
        self.n_pyramid_levels = n_pyramid_levels
        self.n_orientations = n_orientations
        self.mode = mode
        self.use_l2_foveal_loss = use_l2_foveal_loss
        self.fovea_weight = fovea_weight
        self.use_radial_weight = use_radial_weight
        self.use_fullres_l0 = use_fullres_l0
        self.equi = equi
        if self.use_fullres_l0 and self.use_l2_foveal_loss:
            raise Exception(
                "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")

    def calc_statsmaps(self, image, gaze=None, alpha=0.01, real_image_width=0.3,
                       real_viewing_distance=0.6, mode="quadratic", equi=False):
        """
        Builds the list of statistics maps for an image: locally pooled means and standard deviations
        of every band of its steerable pyramid, followed by a lowpass term.

        As a side effect this (re)creates `self.pyramid_maker` and `self.blurs` when needed and, when
        `use_l2_foveal_loss` is enabled, stores `self.fovea_mask` and `self.periphery_mask`.

        Parameters
        ----------

        image                   : torch.tensor
                                    Image in NCHW format (4 dimensions).
        gaze                    : list
                                    Gaze location, forwarded to the blur as its `centre` argument.
        alpha                   : float
                                    Foveation parameter forwarded to the blur.
        real_image_width        : float
                                    Forwarded to the blur.
        real_viewing_distance   : float
                                    Forwarded to the blur.
        mode                    : str
                                    Foveation mode forwarded to the blur.
        equi                    : bool
                                    Unused; kept for interface compatibility. `self.equi` is what is
                                    passed to the blur.

        Returns
        -------

        output_stats            : list of torch.tensor
                                    [mean, std] of the highpass residual, then [mean, std] for every oriented band of
                                    every level, then a single lowpass entry.
        """
        # Rebuild the pyramid if the device, orientation count or channel count no longer match.
        if self.pyramid_maker is None or \
                self.pyramid_maker.device != self.device or \
                len(self.pyramid_maker.band_filters) != self.n_orientations or\
                self.pyramid_maker.filt_h0.size(0) != image.size(1):
            self.pyramid_maker = SpatialSteerablePyramid(
                use_bilinear_downup=False, n_channels=image.size(1),
                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)

        # One blur object per pyramid level.
        if self.blurs is None or len(self.blurs) != self.n_pyramid_levels:
            self.blurs = [RadiallyVaryingBlur()
                          for _ in range(self.n_pyramid_levels)]

        def find_stats(image_pyr_level, blur):
            # Pooled mean and pooled mean of squares give the local variance: E[x^2] - E[x]^2.
            image_means = blur.blur(
                image_pyr_level, alpha, real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)
            image_meansq = blur.blur(image_pyr_level*image_pyr_level, alpha,
                                     real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)

            image_vars = image_meansq - (image_means*image_means)
            # Floor the variance so the square root below never sees zero or negative values.
            image_vars[image_vars < 1e-7] = 1e-7
            image_std = torch.sqrt(image_vars)
            if torch.any(torch.isnan(image_means)):
                print(image_means)
                raise Exception("NaN in image means!")
            if torch.any(torch.isnan(image_std)):
                print(image_std)
                raise Exception("NaN in image stdevs!")
            if self.use_fullres_l0:
                # Zero the stats wherever the blur's level-of-detail map is (numerically) zero.
                # The 2D matte broadcasts over the batch and channel dimensions.
                matte = (blur.lod_map > 1e-6).to(
                    device=image_means.device, dtype=image_means.dtype)
                return image_means * matte, image_std * matte
            return image_means, image_std

        output_stats = []
        image_pyramid = self.pyramid_maker.construct_pyramid(
            image, self.n_pyramid_levels)
        means, stds = find_stats(image_pyramid[0]['h'], self.blurs[0])
        if self.use_l2_foveal_loss:
            # The fovea mask is 1 where the level-of-detail map is zero and falls off steeply
            # (10th power) as the level of detail grows.
            lod_map = self.blurs[0].lod_map
            falloff = 1.0 - (lod_map / torch.max(lod_map))
            falloff[lod_map < 1e-6] = 1.0
            self.fovea_mask = torch.zeros(image.size(), device=image.device)
            # Broadcast over every batch element and channel (previously only batch
            # element 0 was filled, leaving the mask at zero for the rest of the batch).
            self.fovea_mask[...] = falloff.to(image.device)
            self.fovea_mask = torch.pow(self.fovea_mask, 10.0)
            periphery_mask = 1.0 - self.fovea_mask
            self.periphery_mask = periphery_mask.clone()
            output_stats.append(means * periphery_mask)
            output_stats.append(stds * periphery_mask)
        else:
            output_stats.append(means)
            output_stats.append(stds)

        for level in range(len(image_pyramid) - 1):
            for band in image_pyramid[level]['b']:
                means, stds = find_stats(band, self.blurs[level])
                if self.use_l2_foveal_loss:
                    output_stats.append(means * periphery_mask)
                    output_stats.append(stds * periphery_mask)
                else:
                    output_stats.append(means)
                    output_stats.append(stds)
            if self.use_l2_foveal_loss:
                # Halve the mask resolution before it is applied to the next level.
                periphery_mask = torch.nn.functional.interpolate(
                    periphery_mask, scale_factor=0.5, mode="area", recompute_scale_factor=False)

        if self.use_l2_foveal_loss:
            output_stats.append(image_pyramid[-1]["l"] * periphery_mask)
        elif self.use_fullres_l0:
            # Pass equi explicitly so this blur agrees with the blurs used in find_stats.
            output_stats.append(self.blurs[0].blur(
                image, alpha, real_image_width, real_viewing_distance, gaze, mode, equi=self.equi))
        else:
            output_stats.append(image_pyramid[-1]["l"])
        return output_stats

    def metameric_loss_stats(self, statsmap_a, statsmap_b, gaze):
        """
        Averages the mean squared error between corresponding statistics maps.

        Parameters
        ----------

        statsmap_a  : list of torch.tensor
                        Statistics maps of the first image, as returned by calc_statsmaps.
        statsmap_b  : list of torch.tensor
                        Statistics maps of the second image, in the same order.
        gaze        : list
                        Gaze location; only used when use_radial_weight is enabled.

        Returns
        -------

        loss        : torch.tensor
                        Sum of the per-map MSE values divided by the number of maps in statsmap_a.
        """
        loss = 0.0
        # Radial weights depend only on the map size and device, so build each one once.
        radial_weights = {}
        for a, b in zip(statsmap_a, statsmap_b):
            if self.use_radial_weight:
                map_size = (a.size(-2), a.size(-1))
                cache_key = (map_size, a.device)
                if cache_key not in radial_weights:
                    radii = make_radial_map(list(map_size), gaze).to(a.device)
                    # Weight is 1.1 at the gaze point and decays with the 4th power of the radius.
                    weights = 1.1 - (radii * radii * radii * radii)
                    radial_weights[cache_key] = weights[None, None, ...]
                weights = radial_weights[cache_key]
                loss += torch.nn.functional.mse_loss(weights*a, weights*b)
            else:
                loss += torch.nn.functional.mse_loss(a, b)
        loss /= len(statsmap_a)
        return loss

    def visualise_loss_map(self, image_stats):
        """
        Accumulates the absolute differences between `image_stats` and the cached target statistics
        into a single 2D map and stores it in `self.loss_map`.

        Every difference map is bilinearly resized to the size of the first statistics map; only
        batch element 0 and channel 0 contribute.

        Parameters
        ----------

        image_stats : list of torch.tensor
                        Statistics maps of the image, as returned by calc_statsmaps.
        """
        reference = image_stats[0]
        # Allocate on the same device as the stats so the accumulation below cannot mix devices.
        loss_map = torch.zeros(reference.size()[-2:], device=reference.device)
        for stats, target_stats in zip(image_stats, self.target_stats):
            stat_error_map = torch.abs(stats - target_stats)
            stat_error_map = torch.nn.functional.interpolate(
                stat_error_map, size=loss_map.size(), mode="bilinear",
                align_corners=False, recompute_scale_factor=False)
            loss_map += stat_error_map[0, 0, ...]
        self.loss_map = loss_map

    @staticmethod
    def _gaze_key(gaze):
        """Returns a hashable, comparable snapshot of the gaze, or None if it cannot be converted."""
        if gaze is None:
            return None
        try:
            return tuple(float(g) for g in gaze)
        except (TypeError, ValueError):
            # Unknown gaze format: fall back to not tracking gaze changes.
            return None

    def _target_is_stale(self, target, gaze_key):
        """Returns True when the cached target statistics cannot be reused for `target` and `gaze_key`."""
        cached = self.target
        if cached is None or self.target_stats is None:
            return True
        if cached.shape != target.shape or cached.device != target.device:
            return True
        if gaze_key != self.target_gaze:
            return True
        return not bool(torch.all(torch.eq(target, cached)))

    def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
        """ 
        Calculates the Metameric Loss.

        Parameters
        ----------
        image               : torch.tensor
                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target              : torch.tensor
                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        image_colorspace    : str
                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
                                accepted values: RGB, YCrCb.
        gaze                : list
                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
        visualise_loss      : bool
                                If True, a map of the absolute differences between image and target statistics
                                is computed and stored in `self.loss_map`.

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.

        Raises
        ------

        Exception
                                If the target is not a torch tensor.
        """
        check_loss_inputs("MetamericLoss", image, target)
        # Pad image and target if necessary
        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
        # If input is RGB, convert to YCrCb.
        if image.size(1) == 3 and image_colorspace == "RGB":
            image = rgb_2_ycrcb(image)
            target = rgb_2_ycrcb(target)
        # isinstance (rather than an exact type match) also accepts Tensor subclasses.
        if not isinstance(target, torch.Tensor):
            raise Exception("Target of incorrect type")

        blur_settings = dict(
            gaze=gaze,
            alpha=self.alpha,
            real_image_width=self.real_image_width,
            real_viewing_distance=self.real_viewing_distance,
            mode=self.mode
        )
        # Target statistics are expensive, so they are only recomputed when the target or gaze changed.
        gaze_key = self._gaze_key(gaze)
        if self._target_is_stale(target, gaze_key):
            self.target = target.detach().clone()
            self.target_stats = self.calc_statsmaps(self.target, **blur_settings)
            self.target_gaze = gaze_key
        image_stats = self.calc_statsmaps(image, **blur_settings)
        if visualise_loss:
            self.visualise_loss_map(image_stats)
        loss = self.metameric_loss_stats(image_stats, self.target_stats, gaze)
        if self.use_l2_foveal_loss:
            # Add a direct L2 term inside the fovea, scaled by fovea_weight.
            foveal_loss = torch.nn.functional.mse_loss(
                self.fovea_mask*image, self.fovea_mask*target)
            loss = loss + self.fovea_weight * foveal_loss
        return loss

    def to(self, device):
        """
        Sets the device used for the steerable pyramid (it is rebuilt on the next call to calc_statsmaps)
        and returns the loss object itself so calls can be chained.
        """
        self.device = device
        return self

__call__(image, target, gaze=[0.5, 0.5], image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • image_colorspace
                    The current colorspace of your image and target. Ignored if input does not have 3 channels.
                    accepted values: RGB, YCrCb.
    
  • gaze
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    
  • visualise_loss
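                    If True, computes a map of the absolute differences between the image and target statistics (indicating which parts of the image contributed most to the loss) and stores it in `loss_map`.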
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/metameric_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
    """ 
    Calculates the Metameric Loss.

    The statistics of the target are cached on the instance and only recomputed when the target
    (values, shape or device) or the gaze changes.

    Parameters
    ----------
    image               : torch.tensor
                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target              : torch.tensor
                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    image_colorspace    : str
                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
                            accepted values: RGB, YCrCb.
    gaze                : list
                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    visualise_loss      : bool
                            If True, a map of the differences between image and target statistics
                            is computed and stored in `self.loss_map`.

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.

    Raises
    ------

    Exception
                            If the target is not a torch tensor.
    """
    check_loss_inputs("MetamericLoss", image, target)
    # Pad image and target if necessary
    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
    # If input is RGB, convert to YCrCb.
    if image.size(1) == 3 and image_colorspace == "RGB":
        image = rgb_2_ycrcb(image)
        target = rgb_2_ycrcb(target)
    # isinstance (rather than an exact type match) also accepts Tensor subclasses.
    if not isinstance(target, torch.Tensor):
        raise Exception("Target of incorrect type")

    # Snapshot of the gaze used to detect gaze changes between calls.
    try:
        gaze_key = None if gaze is None else tuple(float(g) for g in gaze)
    except (TypeError, ValueError):
        # Unknown gaze format: fall back to not tracking gaze changes.
        gaze_key = None

    # The cached statistics are reusable only for the same target values, shape, device and gaze.
    # A missing cache (first call) always triggers a computation, even for an all-zero target.
    cached_target = self.target
    target_is_stale = (
        cached_target is None
        or getattr(self, "target_stats", None) is None
        or cached_target.shape != target.shape
        or cached_target.device != target.device
        or gaze_key != getattr(self, "target_gaze", None)
        or not bool(torch.all(torch.eq(target, cached_target)))
    )
    if target_is_stale:
        self.target = target.detach().clone()
        self.target_stats = self.calc_statsmaps(
            self.target,
            gaze=gaze,
            alpha=self.alpha,
            real_image_width=self.real_image_width,
            real_viewing_distance=self.real_viewing_distance,
            mode=self.mode
        )
        self.target_gaze = gaze_key
    image_stats = self.calc_statsmaps(
        image,
        gaze=gaze,
        alpha=self.alpha,
        real_image_width=self.real_image_width,
        real_viewing_distance=self.real_viewing_distance,
        mode=self.mode
    )
    if visualise_loss:
        self.visualise_loss_map(image_stats)
    loss = self.metameric_loss_stats(image_stats, self.target_stats, gaze)
    if self.use_l2_foveal_loss:
        # Add a direct L2 term inside the fovea, scaled by fovea_weight.
        foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)
        loss = loss + self.fovea_weight * foveal_loss
    return loss

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, n_pyramid_levels=5, mode='quadratic', n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False, use_fullres_l0=False, equi=False)

Parameters:

  • alpha
                        parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
    
  • real_viewing_distance
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
    
  • n_pyramid_levels
                        Number of levels of the steerable pyramid. Note that the image is padded
                        so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                        too high will slow down the calculation a lot.
    
  • mode
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • n_orientations
                        Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                        Increasing this will increase runtime.
    
  • use_l2_foveal_loss
                        If true, for all the pixels that have pooling size 1 pixel in the 
                        largest scale will use direct L2 against target rather than pooling over pyramid levels.
                        In practice this gives better results when the loss is used for holography.
    
  • fovea_weight
                        A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
    
  • use_radial_weight
                        If True, will apply a radial weighting when calculating the difference between
                        the source and target stats maps. This weights stats closer to the fovea more than those
                        further away.
    
  • use_fullres_l0
                        If true, stats for the lowpass residual are replaced with blurred versions
                        of the full-resolution source and target images.
    
  • equi
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                        The gaze argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi]
    
Source code in odak/learn/perception/metameric_loss.py
def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
             real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
             n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
             use_fullres_l0=False, equi=False):
    """
    Stores the configuration of the loss and resets its cached state.

    Parameters
    ----------

    device                  : torch.device
                                Stored as `self.device`.
    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
    n_pyramid_levels        : int 
                                Number of levels of the steerable pyramid. Note that the image is padded
                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                too high will slow down the calculation a lot.
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    n_orientations          : int 
                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                Increasing this will increase runtime.
    use_l2_foveal_loss      : bool 
                                If true, for all the pixels that have pooling size 1 pixel in the 
                                largest scale will use direct L2 against target rather than pooling over pyramid levels.
                                In practice this gives better results when the loss is used for holography.
    fovea_weight            : float 
                                A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
    use_radial_weight       : bool 
                                If True, will apply a radial weighting when calculating the difference between
                                the source and target stats maps. This weights stats closer to the fovea more than those
                                further away.
    use_fullres_l0          : bool 
                                If true, stats for the lowpass residual are replaced with blurred versions
                                of the full-resolution source and target images.
    equi                    : bool
                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                The gaze argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi]

    Raises
    ------

    Exception
                                If use_fullres_l0 and use_l2_foveal_loss are both enabled.
    """
    # State that starts out empty.
    self.target = None
    self.pyramid_maker = None
    self.blurs = None

    # Viewing geometry and foveation settings.
    self.device = device
    self.alpha, self.mode, self.equi = alpha, mode, equi
    self.real_image_width = real_image_width
    self.real_viewing_distance = real_viewing_distance

    # Steerable pyramid layout.
    self.n_pyramid_levels = n_pyramid_levels
    self.n_orientations = n_orientations

    # Loss variants.
    self.use_l2_foveal_loss = use_l2_foveal_loss
    self.fovea_weight = fovea_weight
    self.use_radial_weight = use_radial_weight
    self.use_fullres_l0 = use_fullres_l0

    # The two lowpass/foveal variants are mutually exclusive.
    options_conflict = self.use_fullres_l0 and self.use_l2_foveal_loss
    if options_conflict:
        raise Exception(
            "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")

MetamericLossUniform

Measures metameric loss between a given image and a metamer of the given target image. This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.

Source code in odak/learn/perception/metameric_loss_uniform.py
class MetamericLossUniform():
    """
    Measures metameric loss between a given image and a metamer of the given target image.
    This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.

    The statistics maps of the most recent target are cached on the instance and recomputed whenever
    the target (values, shape or device) changes.
    """

    def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
        """

        Parameters
        ----------
        device                  : torch.device
                                  Device that the steerable pyramid is built on.
        pooling_size            : int
                                  Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
        n_pyramid_levels        : int 
                                  Number of levels of the steerable pyramid. Note that the image is padded
                                  so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                  too high will slow down the calculation a lot.
        n_orientations          : int 
                                  Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                  Increasing this will increase runtime.

        """
        # Cache of the last target (after padding / colour conversion) and its statistics maps.
        self.target = None
        self.target_stats = None
        self.device = device
        # Created lazily in calc_statsmaps.
        self.pyramid_maker = None
        self.pooling_size = pooling_size
        self.n_pyramid_levels = n_pyramid_levels
        self.n_orientations = n_orientations

    def calc_statsmaps(self, image, pooling_size):
        """
        Builds the list of statistics maps for an image: uniformly pooled means and standard deviations
        of every band of its steerable pyramid, followed by the lowpass residual.

        Parameters
        ----------
        image           : torch.tensor
                          Image in NCHW format (4 dimensions).
        pooling_size    : int
                          Pooling size used for the highpass residual and the first pyramid level.
                          It is halved for each subsequent level.

        Returns
        -------
        output_stats    : list of torch.tensor
                          [mean, std] of the highpass residual, then [mean, std] for every oriented band of
                          every level, then the lowpass residual.
        """
        # Rebuild the pyramid if the device, orientation count or channel count no longer match.
        if self.pyramid_maker is None or \
                self.pyramid_maker.device != self.device or \
                len(self.pyramid_maker.band_filters) != self.n_orientations or\
                self.pyramid_maker.filt_h0.size(0) != image.size(1):
            self.pyramid_maker = SpatialSteerablePyramid(
                use_bilinear_downup=False, n_channels=image.size(1),
                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)

        def find_stats(image_pyr_level, pooling_size):
            # Pooled mean and pooled mean of squares give the local variance: E[x^2] - E[x]^2.
            image_means = uniform_blur(image_pyr_level, pooling_size)
            image_meansq = uniform_blur(image_pyr_level*image_pyr_level, pooling_size)
            image_vars = image_meansq - (image_means*image_means)
            # Floor the variance so the square root below never sees zero or negative values.
            image_vars[image_vars < 1e-7] = 1e-7
            image_std = torch.sqrt(image_vars)
            if torch.any(torch.isnan(image_means)):
                print(image_means)
                raise Exception("NaN in image means!")
            if torch.any(torch.isnan(image_std)):
                print(image_std)
                raise Exception("NaN in image stdevs!")
            return image_means, image_std

        output_stats = []
        image_pyramid = self.pyramid_maker.construct_pyramid(
            image, self.n_pyramid_levels)
        curr_pooling_size = pooling_size
        means, stds = find_stats(image_pyramid[0]['h'], curr_pooling_size)
        output_stats.append(means)
        output_stats.append(stds)

        for level in range(len(image_pyramid) - 1):
            for band in image_pyramid[level]['b']:
                means, stds = find_stats(band, curr_pooling_size)
                output_stats.append(means)
                output_stats.append(stds)
            # True division: the pooling size becomes a float after the first level.
            curr_pooling_size /= 2

        output_stats.append(image_pyramid[-1]["l"])
        return output_stats

    def metameric_loss_stats(self, statsmap_a, statsmap_b):
        """
        Averages the mean squared error between corresponding statistics maps.

        Parameters
        ----------
        statsmap_a  : list of torch.tensor
                      Statistics maps of the first image, as returned by calc_statsmaps.
        statsmap_b  : list of torch.tensor
                      Statistics maps of the second image, in the same order.

        Returns
        -------
        loss        : torch.tensor
                      Sum of the per-map MSE values divided by the number of maps in statsmap_a.
        """
        loss = 0.0
        for a, b in zip(statsmap_a, statsmap_b):
            loss += torch.nn.functional.mse_loss(a, b)
        loss /= len(statsmap_a)
        return loss

    def visualise_loss_map(self, image_stats):
        """
        Accumulates the absolute differences between `image_stats` and the cached target statistics
        into a single 2D map and stores it in `self.loss_map`.

        Every difference map is bilinearly resized to the size of the first statistics map; only
        batch element 0 and channel 0 contribute.

        Parameters
        ----------
        image_stats : list of torch.tensor
                      Statistics maps of the image, as returned by calc_statsmaps.
        """
        reference = image_stats[0]
        # Allocate on the same device as the stats so the accumulation below cannot mix devices.
        loss_map = torch.zeros(reference.size()[-2:], device=reference.device)
        for stats, target_stats in zip(image_stats, self.target_stats):
            stat_error_map = torch.abs(stats - target_stats)
            stat_error_map = torch.nn.functional.interpolate(
                stat_error_map, size=loss_map.size(), mode="bilinear",
                align_corners=False, recompute_scale_factor=False)
            loss_map += stat_error_map[0, 0, ...]
        self.loss_map = loss_map

    def _target_is_stale(self, target):
        """Returns True when the cached target statistics cannot be reused for `target`."""
        cached = self.target
        if cached is None or self.target_stats is None:
            return True
        if cached.shape != target.shape or cached.device != target.device:
            return True
        return not bool(torch.all(torch.eq(target, cached)))

    @staticmethod
    def _seeded_noise_like(image):
        """Returns uniform noise shaped like `image`, drawn from a private generator seeded with 0."""
        # A private generator keeps the result reproducible without reseeding the global RNG.
        generator = torch.Generator(device=image.device)
        generator.manual_seed(0)
        return torch.rand(image.shape, generator=generator,
                          dtype=image.dtype, device=image.device)

    def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
        """ 
        Calculates the Metameric Loss.

        Parameters
        ----------
        image               : torch.tensor
                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target              : torch.tensor
                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        image_colorspace    : str
                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
                                accepted values: RGB, YCrCb.
        visualise_loss      : bool
                                If True, a map of the absolute differences between image and target statistics
                                is computed and stored in `self.loss_map`.

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.

        Raises
        ------

        Exception
                                If the target is not a torch tensor.
        """
        check_loss_inputs("MetamericLossUniform", image, target)
        # Pad image and target if necessary
        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
        # If input is RGB, convert to YCrCb.
        if image.size(1) == 3 and image_colorspace == "RGB":
            image = rgb_2_ycrcb(image)
            target = rgb_2_ycrcb(target)
        # isinstance (rather than an exact type match) also accepts Tensor subclasses.
        if not isinstance(target, torch.Tensor):
            raise Exception("Target of incorrect type")

        # Target statistics are only recomputed when the target changed.
        if self._target_is_stale(target):
            self.target = target.detach().clone()
            self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)
        image_stats = self.calc_statsmaps(image, self.pooling_size)

        if visualise_loss:
            self.visualise_loss_map(image_stats)
        return self.metameric_loss_stats(image_stats, self.target_stats)

    def gen_metamer(self, image):
        """ 
        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
        This function can be used on its own to generate a metamer for a desired image.

        The noise the metamer is built from comes from a fixed seed, so the result is deterministic;
        the global random number generator state is left untouched.

        Parameters
        ----------
        image   : torch.tensor
                  Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)

        Returns
        -------
        metamer : torch.tensor
                  The generated metamer image
        """
        image = rgb_2_ycrcb(image)
        # Remember the unpadded size so the result can be cropped back at the end.
        image_size = image.size()
        image = pad_image_for_pyramid(image, self.n_pyramid_levels)

        target_stats = self.calc_statsmaps(
            image, self.pooling_size)
        # calc_statsmaps interleaves means (even indices) and standard deviations (odd indices).
        target_means = target_stats[::2]
        target_stdevs = target_stats[1::2]
        noise_image = self._seeded_noise_like(image)
        noise_pyramid = self.pyramid_maker.construct_pyramid(
            noise_image, self.n_pyramid_levels)
        input_pyramid = self.pyramid_maker.construct_pyramid(
            image, self.n_pyramid_levels)

        def match_level(input_level, target_mean, target_std):
            # Normalise the noise band to zero mean / unit std, then impose the target statistics.
            level = input_level.clone()
            level -= torch.mean(level)
            input_std = torch.sqrt(torch.mean(level * level))
            eps = 1e-6
            # Safeguard against divide by zero
            input_std[input_std < eps] = eps
            level /= input_std
            level *= target_std
            level += target_mean
            return level

        nbands = len(noise_pyramid[0]["b"])
        noise_pyramid[0]["h"] = match_level(
            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
        for level in range(len(noise_pyramid) - 1):
            for band in range(nbands):
                # Index 0 holds the highpass stats, so band stats start at offset 1.
                stat_index = 1 + level * nbands + band
                noise_pyramid[level]["b"][band] = match_level(
                    noise_pyramid[level]["b"][band], target_means[stat_index], target_stdevs[stat_index])
        # The lowpass residual is copied from the input rather than synthesised.
        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

        metamer = self.pyramid_maker.reconstruct_from_pyramid(
            noise_pyramid)
        metamer = ycrcb_2_rgb(metamer)
        # Crop to remove any padding
        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
        return metamer

    def to(self, device):
        """
        Sets the device used for the steerable pyramid (it is rebuilt on the next call to calc_statsmaps)
        and returns the loss object itself so calls can be chained.
        """
        self.device = device
        return self

__call__(image, target, image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • image_colorspace
                    The current colorspace of your image and target. Ignored if input does not have 3 channels.
                    accepted values: RGB, YCrCb.
    
  • visualise_loss
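                    If True, computes a map of the absolute differences between the image and target statistics (indicating which parts of the image contributed most to the loss) and stores it in `loss_map`.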
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/metameric_loss_uniform.py
def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
    """ 
    Calculates the Metameric Loss.

    The statistics of the target are cached on this object. They are recomputed only when the
    (padded and colour-converted) target differs from the one cached by the previous call.

    Parameters
    ----------
    image               : torch.tensor
                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target              : torch.tensor
                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    image_colorspace    : str
                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
                            accepted values: RGB, YCrCb.
    visualise_loss      : bool
                            Shows a heatmap indicating which parts of the image contributed most to the loss. 

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.

    Raises
    ------
    Exception
                            If the target is not a torch tensor.
    """
    check_loss_inputs("MetamericLossUniform", image, target)
    # Pad image and target if necessary
    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
    # If input is RGB, convert to YCrCb.
    if image.size(1) == 3 and image_colorspace == "RGB":
        image = rgb_2_ycrcb(image)
        target = rgb_2_ycrcb(target)
    if not isinstance(target, torch.Tensor):
        raise Exception("Target of incorrect type")

    # Decide whether the cached target statistics are still valid. An empty cache (None) always
    # forces a recompute, so an all-zero first target can no longer be mistaken for "already cached".
    # Shape/device/dtype are checked first so the element-wise comparison is only attempted on
    # tensors that can actually be compared.
    cached_target = self.target
    target_changed = (
        cached_target is None
        or cached_target.shape != target.shape
        or cached_target.device != target.device
        or cached_target.dtype != target.dtype
        or not torch.equal(cached_target, target)
    )
    if target_changed:
        self.target = target.detach().clone()
        self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)

    image_stats = self.calc_statsmaps(image, self.pooling_size)

    if visualise_loss:
        self.visualise_loss_map(image_stats)
    loss = self.metameric_loss_stats(
        image_stats, self.target_stats)
    return loss

__init__(device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2)

Parameters:

  • pooling_size
                      Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
    
  • n_pyramid_levels
                      Number of levels of the steerable pyramid. Note that the image is padded
                      so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                      too high will slow down the calculation a lot.
    
  • n_orientations
                      Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                      Increasing this will increase runtime.
    
Source code in odak/learn/perception/metameric_loss_uniform.py
def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
    """
    Stores the configuration of the loss. No pyramid is built and no target is cached at this point.

    Parameters
    ----------
    device                  : torch.device
                              Stored on this object as `self.device`.
    pooling_size            : int
                              Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
    n_pyramid_levels        : int 
                              Number of levels of the steerable pyramid. Note that the image is padded
                              so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                              too high will slow down the calculation a lot.
    n_orientations          : int 
                              Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                              Increasing this will increase runtime.

    """
    # User supplied configuration.
    self.n_orientations = n_orientations
    self.n_pyramid_levels = n_pyramid_levels
    self.pooling_size = pooling_size
    self.device = device
    # Placeholders: both start out empty.
    self.pyramid_maker, self.target = None, None

gen_metamer(image)

Generates a metamer for an image, following the method in this paper. This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image
      Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
    

Returns:

  • metamer ( tensor ) –

    The generated metamer image

Source code in odak/learn/perception/metameric_loss_uniform.py
def gen_metamer(self, image):
    """ 
    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943).
    This function can be used on its own to generate a metamer for a desired image.

    A pyramid is built from seeded random noise, each of its high-pass and band-pass levels is
    rescaled to the mean / standard deviation maps computed from the input image, and its
    low-pass residual is replaced by the one of the input image before reconstruction.

    Parameters
    ----------
    image   : torch.tensor
              Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)

    Returns
    -------
    metamer : torch.tensor
              The generated metamer image
    """
    image = rgb_2_ycrcb(image)
    original_size = image.size()
    image = pad_image_for_pyramid(image, self.n_pyramid_levels)

    # The statistics list interleaves means (even entries) and standard deviations (odd entries).
    stats = self.calc_statsmaps(image, self.pooling_size)
    target_means, target_stdevs = stats[::2], stats[1::2]

    # Fixed seed: note that this resets the global torch random number generator.
    torch.manual_seed(0)
    noise_image = torch.rand_like(image)
    noise_pyramid = self.pyramid_maker.construct_pyramid(
        noise_image, self.n_pyramid_levels)
    input_pyramid = self.pyramid_maker.construct_pyramid(
        image, self.n_pyramid_levels)

    def match_level(input_level, target_mean, target_std):
        # Work on a copy: normalise to zero mean / unit RMS, then apply the target statistics.
        matched = input_level.clone()
        matched -= matched.mean()
        rms = torch.sqrt((matched * matched).mean())
        eps = 1e-6
        # Safeguard against divide by zero
        rms[rms < eps] = eps
        matched /= rms
        matched *= target_std
        matched += target_mean
        return matched

    nbands = len(noise_pyramid[0]["b"])
    # Entry 0 of the statistics belongs to the high-pass residual of the first level.
    noise_pyramid[0]["h"] = match_level(
        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
    # The remaining entries are consumed band by band, level by level (last level excluded).
    stat_index = 1
    for level_index in range(len(noise_pyramid) - 1):
        for band_index in range(nbands):
            noise_pyramid[level_index]["b"][band_index] = match_level(
                noise_pyramid[level_index]["b"][band_index],
                target_means[stat_index], target_stdevs[stat_index])
            stat_index += 1
    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

    metamer = ycrcb_2_rgb(
        self.pyramid_maker.reconstruct_from_pyramid(noise_pyramid))
    # Crop to remove any padding
    return metamer[:original_size[0], :original_size[1], :original_size[2], :original_size[3]]

MetamerMSELoss

The MetamerMSELoss class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.

Please note this is different to MetamericLoss which optimises the source image to be any metamer of the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/metamer_mse_loss.py
class MetamerMSELoss():
    """ 
    The `MetamerMSELoss` class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.

    Please note this is different to `MetamericLoss` which optimises the source image to be any metamer of the target image.

    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.

    The metamer of the target is cached: it is regenerated only when a different target tensor object
    or a different gaze is passed to `__call__`.
    """


    def __init__(self, device=torch.device("cpu"),
                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
                 n_pyramid_levels=5, n_orientations=2, equi=False):
        """
        Parameters
        ----------
        device                  : torch.device
                                    Passed on to the internal `MetamericLoss`.
        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
        n_pyramid_levels        : int 
                                    Number of levels of the steerable pyramid. Note that the image is padded
                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                    too high will slow down the calculation a lot.
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        n_orientations          : int 
                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                    Increasing this will increase runtime.
        equi                    : bool
                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi/2]
        """
        # Cache of the last target (as passed by the caller), the gaze it was used with and its metamer.
        self.target = None
        self.target_gaze = None
        self.target_metamer = None
        # NOTE(review): `mode` is accepted above but is not forwarded here - confirm whether
        # MetamericLoss takes a `mode` argument and, if so, pass it through.
        self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,
                                            real_viewing_distance=real_viewing_distance,
                                            n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)
        self.loss_func = torch.nn.MSELoss()
        # Cached noise image used to seed metamer generation; created on first use.
        self.noise = None

    def gen_metamer(self, image, gaze):
        """ 
        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943).
        This function can be used on its own to generate a metamer for a desired image.

        The noise image the metamer is built from is generated with a fixed seed and kept in `self.noise`;
        it is regenerated only when the size, device or dtype of the (padded) input changes.

        Parameters
        ----------
        image   : torch.tensor
                Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
        gaze    : list
                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

        Returns
        -------

        metamer : torch.tensor
                The generated metamer image
        """
        image = rgb_2_ycrcb(image)
        image_size = image.size()
        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)

        target_stats = self.metameric_loss.calc_statsmaps(
            image, gaze=gaze, alpha=self.metameric_loss.alpha)
        # The statistics list interleaves means (even entries) and standard deviations (odd entries).
        target_means = target_stats[::2]
        target_stdevs = target_stats[1::2]

        noise_is_stale = (
            self.noise is None
            or self.noise.size() != image.size()
            or self.noise.device != image.device
            or self.noise.dtype != image.dtype
        )
        if noise_is_stale:
            # Fixed seed so a given input size always starts from the same noise.
            # Note that this resets the global torch random number generator.
            torch.manual_seed(0)
            self.noise = torch.rand_like(image)
        # Always read the noise from the cache, so a previously stored noise image is actually reused.
        noise_image = self.noise

        noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
            noise_image, self.metameric_loss.n_pyramid_levels)
        input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
            image, self.metameric_loss.n_pyramid_levels)

        def match_level(input_level, target_mean, target_std):
            # Normalise a copy of the level to zero mean / unit RMS, then apply the target statistics.
            level = input_level.clone()
            level -= torch.mean(level)
            input_std = torch.sqrt(torch.mean(level * level))
            eps = 1e-6
            # Safeguard against divide by zero
            input_std[input_std < eps] = eps
            level /= input_std
            level *= target_std
            level += target_mean
            return level

        nbands = len(noise_pyramid[0]["b"])
        noise_pyramid[0]["h"] = match_level(
            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
        for l in range(len(noise_pyramid)-1):
            for b in range(nbands):
                noise_pyramid[l]["b"][b] = match_level(
                    noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
        # The low-pass residual is taken directly from the input image.
        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

        metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
            noise_pyramid)
        metamer = ycrcb_2_rgb(metamer)
        # Crop to remove any padding
        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
        return metamer

    def __call__(self, image, target, gaze=[0.5, 0.5]):
        """ 
        Calculates the Metamer MSE Loss.

        The target metamer is regenerated only when `target` is a different tensor object than on the
        previous call, or when `gaze` has changed; otherwise the cached metamer is reused.

        Parameters
        ----------
        image   : torch.tensor
                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target  : torch.tensor
                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        gaze    : list
                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("MetamerMSELoss", image, target)
        # Pad image if necessary
        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)

        # The cache is keyed on the tensor object the caller passed (before padding, which may
        # return a new tensor) together with the gaze, since the metamer depends on both.
        gaze_key = tuple(float(g) for g in gaze)
        cache_is_stale = (
            self.target_metamer is None
            or target is not self.target
            or gaze_key != self.target_gaze
        )
        if cache_is_stale:
            padded_target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)
            self.target_metamer = self.gen_metamer(padded_target, gaze)
            self.target = target
            self.target_gaze = gaze_key

        return self.loss_func(image, self.target_metamer)

    def to(self, device):
        """
        Calls `to(device)` on the internal `MetamericLoss` and returns this instance.
        """
        self.metameric_loss = self.metameric_loss.to(device)
        return self

__call__(image, target, gaze=[0.5, 0.5])

Calculates the Metamer MSE Loss.

Parameters:

  • image
    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target
    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • gaze
    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/metamer_mse_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):
    """ 
    Calculates the Metamer MSE Loss.

    The target metamer is regenerated only when `target` is a different tensor object than on the
    previous call, or when `gaze` has changed; otherwise the cached metamer is reused.

    Parameters
    ----------
    image   : torch.tensor
            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target  : torch.tensor
            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    gaze    : list
            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("MetamerMSELoss", image, target)
    # Pad image if necessary
    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)

    # The cache is keyed on the tensor object the caller passed (before padding, which may
    # return a new tensor) together with the gaze, since the metamer depends on both.
    # `target_gaze` is read with getattr because it is only ever set by this method.
    gaze_key = tuple(float(g) for g in gaze)
    cache_is_stale = (
        self.target_metamer is None
        or target is not self.target
        or gaze_key != getattr(self, "target_gaze", None)
    )
    if cache_is_stale:
        padded_target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)
        self.target_metamer = self.gen_metamer(padded_target, gaze)
        self.target = target
        self.target_gaze = gaze_key

    return self.loss_func(image, self.target_metamer)

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', n_pyramid_levels=5, n_orientations=2, equi=False)

Parameters:

  • alpha
                        parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
    
  • real_viewing_distance
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
    
  • n_pyramid_levels
                        Number of levels of the steerable pyramid. Note that the image is padded
                        so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                        too high will slow down the calculation a lot.
    
  • mode
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • n_orientations
                        Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                        Increasing this will increase runtime.
    
  • equi
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                        The gaze argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi/2]
    
Source code in odak/learn/perception/metamer_mse_loss.py
def __init__(self, device=torch.device("cpu"),
             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
             n_pyramid_levels=5, n_orientations=2, equi=False):
    """
    Sets up the internal `MetamericLoss` and MSE loss, and clears the cached target, metamer and noise.

    Parameters
    ----------
    device                  : torch.device
                                Passed on to the internal `MetamericLoss`.
    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
    n_pyramid_levels        : int 
                                Number of levels of the steerable pyramid. Note that the image is padded
                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                too high will slow down the calculation a lot.
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    n_orientations          : int 
                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                Increasing this will increase runtime.
    equi                    : bool
                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing distance are ignored.
                                The gaze argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi]
    """
    # Nothing is cached yet: no target, no metamer of it, and no noise image.
    self.target, self.target_metamer = None, None

    # NOTE(review): `mode` is accepted above but is not among the arguments forwarded here -
    # confirm whether MetamericLoss takes a `mode` argument.
    foveation_settings = dict(
        alpha=alpha,
        real_image_width=real_image_width,
        real_viewing_distance=real_viewing_distance,
        equi=equi,
    )
    pyramid_settings = dict(
        n_pyramid_levels=n_pyramid_levels,
        n_orientations=n_orientations,
    )
    self.metameric_loss = MetamericLoss(
        device=device, use_l2_foveal_loss=False, **foveation_settings, **pyramid_settings)

    self.loss_func = torch.nn.MSELoss()
    self.noise = None

gen_metamer(image, gaze)

Generates a metamer for an image, following the method in this paper. This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image
    Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
    
  • gaze
    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    

Returns:

  • metamer ( tensor ) –

    The generated metamer image

Source code in odak/learn/perception/metamer_mse_loss.py
def gen_metamer(self, image, gaze):
    """ 
    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943).
    This function can be used on its own to generate a metamer for a desired image.

    The noise image the metamer is built from is generated with a fixed seed and kept in `self.noise`;
    it is regenerated only when the size, device or dtype of the (padded) input changes.

    Parameters
    ----------
    image   : torch.tensor
            Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
    gaze    : list
            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

    Returns
    -------

    metamer : torch.tensor
            The generated metamer image
    """
    image = rgb_2_ycrcb(image)
    image_size = image.size()
    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)

    target_stats = self.metameric_loss.calc_statsmaps(
        image, gaze=gaze, alpha=self.metameric_loss.alpha)
    # The statistics list interleaves means (even entries) and standard deviations (odd entries).
    target_means = target_stats[::2]
    target_stdevs = target_stats[1::2]

    noise_is_stale = (
        self.noise is None
        or self.noise.size() != image.size()
        or self.noise.device != image.device
        or self.noise.dtype != image.dtype
    )
    if noise_is_stale:
        # Fixed seed so a given input size always starts from the same noise.
        # Note that this resets the global torch random number generator.
        torch.manual_seed(0)
        self.noise = torch.rand_like(image)
    # Always read the noise from the cache, so a previously stored noise image is actually reused.
    noise_image = self.noise

    noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
        noise_image, self.metameric_loss.n_pyramid_levels)
    input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
        image, self.metameric_loss.n_pyramid_levels)

    def match_level(input_level, target_mean, target_std):
        # Normalise a copy of the level to zero mean / unit RMS, then apply the target statistics.
        level = input_level.clone()
        level -= torch.mean(level)
        input_std = torch.sqrt(torch.mean(level * level))
        eps = 1e-6
        # Safeguard against divide by zero
        input_std[input_std < eps] = eps
        level /= input_std
        level *= target_std
        level += target_mean
        return level

    nbands = len(noise_pyramid[0]["b"])
    noise_pyramid[0]["h"] = match_level(
        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
    for l in range(len(noise_pyramid)-1):
        for b in range(nbands):
            noise_pyramid[l]["b"][b] = match_level(
                noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
    # The low-pass residual is taken directly from the input image.
    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

    metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
        noise_pyramid)
    metamer = ycrcb_2_rgb(metamer)
    # Crop to remove any padding
    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
    return metamer

RadiallyVaryingBlur

The RadiallyVaryingBlur class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given alpha parameter value. For more information on how the pooling sizes are computed, please see link coming soon.

The blur is accelerated by generating and sampling from MIP maps of the input image.

This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.

If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.

Source code in odak/learn/perception/radially_varying_blur.py
class RadiallyVaryingBlur():
    """ 

    The `RadiallyVaryingBlur` class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given `alpha` parameter value. For more information on how the pooling sizes are computed, please see [link coming soon]().

    The blur is accelerated by generating and sampling from MIP maps of the input image.

    This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.

    If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.

    """

    def __init__(self):
        # Cached level-of-detail map and the `equi` flag it was built for; both unset until the first blur() call.
        self.lod_map = None
        self.equi = None

    def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
        """
        Apply the radially varying blur to an image.

        Parameters
        ----------

        image                   : torch.tensor
                                    The image to blur, in NCHW format.
        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
                                    Ignored in equirectangular mode (equi==True)
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
                                    Ignored in equirectangular mode (equi==True)
        centre                  : tuple of floats
                                    The centre of the radially varying blur (the gaze location).
                                    Should be a tuple of floats containing normalised image coordinates in range [0,1]
                                    In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        equi                    : bool
                                    If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                    The centre argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi/2]

        Returns
        -------

        output                  : torch.tensor
                                    The blurred image
        """
        size = (image.size(-2), image.size(-1))

        # LOD map caching
        if self.lod_map is None or\
                self.size != size or\
                self.n_channels != image.size(1) or\
                self.alpha != alpha or\
                self.real_image_width != real_image_width or\
                self.real_viewing_distance != real_viewing_distance or\
                self.centre != centre or\
                self.mode != mode or\
                self.equi != equi:
            if not equi:
                self.lod_map = make_pooling_size_map_lod(
                    centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
            else:
                self.lod_map = make_equi_pooling_size_map_lod(
                    centre, (image.size(-2), image.size(-1)), alpha, mode)
            self.size = size
            self.n_channels = image.size(1)
            self.alpha = alpha
            self.real_image_width = real_image_width
            self.real_viewing_distance = real_viewing_distance
            self.centre = centre
            self.lod_map = self.lod_map.to(image.device)
            # Fractional part of the LOD: blend weight between two neighbouring mip levels.
            self.lod_fraction = torch.fmod(self.lod_map, 1.0)
            self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
                1, image.size(1), 1, 1)
            self.mode = mode
            self.equi = equi

        if self.lod_map.device != image.device:
            self.lod_map = self.lod_map.to(image.device)
        if self.lod_fraction.device != image.device:
            self.lod_fraction = self.lod_fraction.to(image.device)

        # Build the mip chain by halving (area averaging) until a single pixel is left.
        # A dimension that has already reached 1 is held there while the other keeps shrinking,
        # so the coarsest level is always 1x1, whatever the aspect ratio of the image.
        mipmap = [image]
        while mipmap[-1].size(-2) > 1 or mipmap[-1].size(-1) > 1:
            next_size = (max(1, mipmap[-1].size(-2) // 2),
                         max(1, mipmap[-1].size(-1) // 2))
            mipmap.append(torch.nn.functional.interpolate(
                mipmap[-1], size=next_size, mode="area"))

        # Bring every level back to full resolution. Level 0 already is the full resolution image.
        last = len(mipmap) - 1
        for l in range(1, len(mipmap)):
            if l == last:
                # 1x1 level: broadcast its single value over the whole image.
                mipmap[l] = mipmap[l] * \
                    torch.ones(image.size(), device=image.device)
            else:
                mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
                    image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)

        output = torch.zeros(image.size(), device=image.device)
        for l in range(len(mipmap)):
            # Select the pixels whose LOD falls into [l, l+1); the first level also takes
            # everything below it and the last level everything above it.
            if last == 0:
                mask = torch.ones_like(self.lod_map, dtype=torch.bool)
            elif l == 0:
                mask = self.lod_map < (l+1)
            elif l == last:
                mask = self.lod_map >= l
            else:
                mask = torch.logical_and(
                    self.lod_map >= l, self.lod_map < (l+1))

            if l == last:
                blended_levels = mipmap[l]
            else:
                blended_levels = (1 - self.lod_fraction) * \
                    mipmap[l] + self.lod_fraction*mipmap[l+1]
            # The (H, W) mask is broadcast over batch and channels, so any batch size is handled.
            output = torch.where(mask[None, None, ...], blended_levels, output)

        return output

blur(image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode='quadratic', equi=False)

Apply the radially varying blur to an image.

Parameters:

  • image
                        The image to blur, in NCHW format.
    
  • alpha
                        parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
                        Ignored in equirectangular mode (equi==True)
    
  • real_viewing_distance
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
                        Ignored in equirectangular mode (equi==True)
    
  • centre
                        The centre of the radially varying blur (the gaze location).
                        Should be a tuple of floats containing normalised image coordinates in range [0,1]
                        In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
    
  • mode
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • equi
                        If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                        The centre argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi/2]
    

Returns:

  • output ( tensor ) –

    The blurred image

Source code in odak/learn/perception/radially_varying_blur.py
def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
    """
    Apply the radially varying blur to an image.

    Parameters
    ----------

    image                   : torch.tensor
                                The image to blur, in NCHW format.
    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
                                Ignored in equirectangular mode (equi==True)
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
                                Ignored in equirectangular mode (equi==True)
    centre                  : tuple of floats
                                The centre of the radially varying blur (the gaze location).
                                Should be a tuple of floats containing normalised image coordinates in range [0,1]
                                In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    equi                    : bool
                                If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing distance are ignored.
                                The centre argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi]

    Returns
    -------

    output                  : torch.tensor
                                The blurred image
    """
    size = (image.size(-2), image.size(-1))

    # LOD map caching
    if self.lod_map is None or\
            self.size != size or\
            self.n_channels != image.size(1) or\
            self.alpha != alpha or\
            self.real_image_width != real_image_width or\
            self.real_viewing_distance != real_viewing_distance or\
            self.centre != centre or\
            self.mode != mode or\
            self.equi != equi:
        if not equi:
            self.lod_map = make_pooling_size_map_lod(
                centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
        else:
            self.lod_map = make_equi_pooling_size_map_lod(
                centre, (image.size(-2), image.size(-1)), alpha, mode)
        self.size = size
        self.n_channels = image.size(1)
        self.alpha = alpha
        self.real_image_width = real_image_width
        self.real_viewing_distance = real_viewing_distance
        self.centre = centre
        self.lod_map = self.lod_map.to(image.device)
        self.lod_fraction = torch.fmod(self.lod_map, 1.0)
        self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
            1, image.size(1), 1, 1)
        self.mode = mode
        self.equi = equi

    if self.lod_map.device != image.device:
        self.lod_map = self.lod_map.to(image.device)
    if self.lod_fraction.device != image.device:
        self.lod_fraction = self.lod_fraction.to(image.device)

    mipmap = [image]
    while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:
        mipmap.append(torch.nn.functional.interpolate(
            mipmap[-1], scale_factor=0.5, mode="area", recompute_scale_factor=False))
    if mipmap[-1].size(-1) == 2:
        final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]
        mipmap.append(final_mip)
    if mipmap[-1].size(-2) == 2:
        final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]
        mipmap.append(final_mip)

    for l in range(len(mipmap)):
        if l == len(mipmap)-1:
            mipmap[l] = mipmap[l] * \
                torch.ones(image.size(), device=image.device)
        else:
            for l2 in range(l-1, -1, -1):
                mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
                    image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)

    output = torch.zeros(image.size(), device=image.device)
    for l in range(len(mipmap)):
        if l == 0:
            mask = self.lod_map < (l+1)
        elif l == len(mipmap)-1:
            mask = self.lod_map >= l
        else:
            mask = torch.logical_and(
                self.lod_map >= l, self.lod_map < (l+1))

        if l == len(mipmap)-1:
            blended_levels = mipmap[l]
        else:
            blended_levels = (1 - self.lod_fraction) * \
                mipmap[l] + self.lod_fraction*mipmap[l+1]
        mask = mask[None<