odak.learn.perception

Defines a number of perceptual loss functions that can be used to optimise images when the gaze location is known.

BlurLoss

BlurLoss implements two different blur losses. When blur_source is set to False, it implements blur_match, which tries to match the source input image to a blurred version of the target.

When blur_source is set to True, it implements blur_lowpass, which matches a blurred version of the input image to the blurred target image. In effect, this matches only the low frequencies of the source input image to the low frequencies of the target.

The interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.
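
A minimal usage sketch is given below. It assumes BlurLoss is imported from odak.learn.perception and uses placeholder 256x256 RGB tensors in NCHW format with values in [0, 1]:

import torch
from odak.learn.perception import BlurLoss

image = torch.rand(1, 3, 256, 256, requires_grad = True) # source image being optimised
target = torch.rand(1, 3, 256, 256)                      # ground truth target

loss_function = BlurLoss(alpha = 0.2, blur_source = False) # blur_match behaviour
loss = loss_function(image, target, gaze = [0.5, 0.5])     # gaze at the image centre
loss.backward()

Setting blur_source=True instead gives the blur_lowpass behaviour described above.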

Source code in odak/learn/perception/blur_loss.py
class BlurLoss():
    """ 

    `BlurLoss` implements two different blur losses. When `blur_source` is set to `False`, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.

    When `blur_source` is set to `True`, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.

    The interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
    """


    def __init__(self, device=torch.device("cpu"),
                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
        """
        Parameters
        ----------

        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        blur_source             : bool
                                    If true, blurs the source image as well as the target before computing the loss.
        equi                    : bool
                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.
                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi]
        """
        self.target = None
        self.device = device
        self.alpha = alpha
        self.real_image_width = real_image_width
        self.real_viewing_distance = real_viewing_distance
        self.mode = mode
        self.blur = None
        self.loss_func = torch.nn.MSELoss()
        self.blur_source = blur_source
        self.equi = equi

    def blur_image(self, image, gaze):
        if self.blur is None:
            self.blur = RadiallyVaryingBlur()
        return self.blur.blur(image, self.alpha, self.real_image_width, self.real_viewing_distance, gaze, self.mode, self.equi)

    def __call__(self, image, target, gaze=[0.5, 0.5]):
        """ 
        Calculates the Blur Loss.

        Parameters
        ----------
        image               : torch.tensor
                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target              : torch.tensor
                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        gaze                : list
                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("BlurLoss", image, target)
        blurred_target = self.blur_image(target, gaze)
        if self.blur_source:
            blurred_image = self.blur_image(image, gaze)
            return self.loss_func(blurred_image, blurred_target)
        else:
            return self.loss_func(image, blurred_target)

    def to(self, device):
        self.device = device
        return self

__call__(image, target, gaze=[0.5, 0.5])

Calculates the Blur Loss.

Parameters:

  • image –
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target –
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • gaze –
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/blur_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):
    """ 
    Calculates the Blur Loss.

    Parameters
    ----------
    image               : torch.tensor
                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target              : torch.tensor
                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    gaze                : list
                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("BlurLoss", image, target)
    blurred_target = self.blur_image(target, gaze)
    if self.blur_source:
        blurred_image = self.blur_image(image, gaze)
        return self.loss_func(blurred_image, blurred_target)
    else:
        return self.loss_func(image, blurred_target)

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', blur_source=False, equi=False)

Parameters:

  • alpha –
                        parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width –
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
    
  • real_viewing_distance –
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
    
  • mode –
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • blur_source –
                        If true, blurs the source image as well as the target before computing the loss.
    
  • equi –
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                        The gaze argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi/2].
    
Source code in odak/learn/perception/blur_loss.py
def __init__(self, device=torch.device("cpu"),
             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
    """
    Parameters
    ----------

    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    blur_source             : bool
                                If true, blurs the source image as well as the target before computing the loss.
    equi                    : bool
                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing distance are ignored.
                                The gaze argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi]
    """
    self.target = None
    self.device = device
    self.alpha = alpha
    self.real_image_width = real_image_width
    self.real_viewing_distance = real_viewing_distance
    self.mode = mode
    self.blur = None
    self.loss_func = torch.nn.MSELoss()
    self.blur_source = blur_source
    self.equi = equi

CVVDP

Bases: Module

Source code in odak/learn/perception/learned_perceptual_losses.py
class CVVDP(nn.Module):
    def __init__(self, device = torch.device('cpu')):
        """
        Initializes the CVVDP model with a specified device.

        Parameters
        ----------
        device   : torch.device
                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
        """
        super(CVVDP, self).__init__()
        try:
            import pycvvdp
            self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)
        except Exception as e:
            logging.warning('ColorVideoVDP is missing, consider installing by running "pip install -U git+https://github.com/gfxdisp/ColorVideoVDP"')
            logging.warning(e)


    def forward(self, predictions, targets, dim_order = 'BCHW'):
        """
        Parameters
        ----------
        predictions   : torch.tensor
                        The predicted images.
        targets       : torch.tensor
                        The ground truth images.
        dim_order     : str
                        The dimension order of the input images. Defaults to 'BCHW' (batch, channels, height, width).

        Returns
        -------
        result        : torch.tensor
                        The computed loss if successful, otherwise 0.0.
        """
        try:
            if len(predictions.shape) == 3:
                predictions = predictions.unsqueeze(0)
                targets = targets.unsqueeze(0)
            l_ColorVideoVDP = self.cvvdp.predict(predictions, targets, dim_order = dim_order)[0]
            return l_ColorVideoVDP
        except Exception as e:
            logging.warning('ColorVideoVDP failed to compute.')
            logging.warning(e)
            return torch.tensor(0.0)
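
A minimal usage sketch follows, assuming pycvvdp is installed (otherwise the loss logs a warning and returns 0.0), that CVVDP is exposed under odak.learn.perception as documented here, and using placeholder tensors:

import torch
from odak.learn.perception import CVVDP

predictions = torch.rand(1, 3, 256, 256) # predicted images, NCHW
targets = torch.rand(1, 3, 256, 256)     # ground truth images, NCHW

loss_function = CVVDP(device = torch.device('cpu'))
loss = loss_function(predictions, targets, dim_order = 'BCHW')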

__init__(device=torch.device('cpu'))

Initializes the CVVDP model with a specified device.

Parameters:

  • device –
        The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
    
Source code in odak/learn/perception/learned_perceptual_losses.py
def __init__(self, device = torch.device('cpu')):
    """
    Initializes the CVVDP model with a specified device.

    Parameters
    ----------
    device   : torch.device
                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
    """
    super(CVVDP, self).__init__()
    try:
        import pycvvdp
        self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)
    except Exception as e:
        logging.warning('ColorVideoVDP is missing, consider installing by running "pip install -U git+https://github.com/gfxdisp/ColorVideoVDP"')
        logging.warning(e)

forward(predictions, targets, dim_order='BCHW')

Parameters:

  • predictions –
            The predicted images.
    
  • targets –
            The ground truth images.
    
  • dim_order –
            The dimension order of the input images. Defaults to 'BCHW' (batch, channels, height, width).
    

Returns:

  • result ( tensor ) –

    The computed loss if successful, otherwise 0.0.

Source code in odak/learn/perception/learned_perceptual_losses.py
def forward(self, predictions, targets, dim_order = 'BCHW'):
    """
    Parameters
    ----------
    predictions   : torch.tensor
                    The predicted images.
    targets       : torch.tensor
                    The ground truth images.
    dim_order     : str
                    The dimension order of the input images. Defaults to 'BCHW' (batch, channels, height, width).

    Returns
    -------
    result        : torch.tensor
                    The computed loss if successful, otherwise 0.0.
    """
    try:
        if len(predictions.shape) == 3:
            predictions = predictions.unsqueeze(0)
            targets = targets.unsqueeze(0)
        l_ColorVideoVDP = self.cvvdp.predict(predictions, targets, dim_order = dim_order)[0]
        return l_ColorVideoVDP
    except Exception as e:
        logging.warning('ColorVideoVDP failed to compute.')
        logging.warning(e)
        return torch.tensor(0.0)

FVVDP

Bases: Module

Source code in odak/learn/perception/learned_perceptual_losses.py
class FVVDP(nn.Module):
    def __init__(self, device = torch.device('cpu')):
        """
        Initializes the FVVDP model with a specified device.

        Parameters
        ----------
        device   : torch.device
                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
        """
        super(FVVDP, self).__init__()
        try:
            import pyfvvdp
            self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)
        except Exception as e:
            logging.warning('FovVideoVDP is missing, consider installing by running "pip install pyfvvdp"')
            logging.warning(e)


    def forward(self, predictions, targets, dim_order = 'BCHW'):
        """
        Parameters
        ----------
        predictions   : torch.tensor
                        The predicted images.
        targets       : torch.tensor
                        The ground truth images.
        dim_order     : str
                        The dimension order of the input images. Defaults to 'BCHW' (batch, channels, height, width).

        Returns
        -------
        result        : torch.tensor
                          The computed loss if successful, otherwise 0.0.
        """
        try:
            if len(predictions.shape) == 3:
                predictions = predictions.unsqueeze(0)
                targets = targets.unsqueeze(0)
            l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]
            return l_FovVideoVDP
        except Exception as e:
            logging.warning('FovVideoVDP failed to compute.')
            logging.warning(e)
            return torch.tensor(0.0)
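
FVVDP follows the same calling pattern as CVVDP above. A minimal sketch, assuming pyfvvdp is installed and using placeholder tensors:

import torch
from odak.learn.perception import FVVDP

predictions = torch.rand(1, 3, 256, 256)
targets = torch.rand(1, 3, 256, 256)

loss_function = FVVDP(device = torch.device('cpu'))
loss = loss_function(predictions, targets, dim_order = 'BCHW')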

__init__(device=torch.device('cpu'))

Initializes the FVVDP model with a specified device.

Parameters:

  • device –
        The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
    
Source code in odak/learn/perception/learned_perceptual_losses.py
def __init__(self, device = torch.device('cpu')):
    """
    Initializes the FVVDP model with a specified device.

    Parameters
    ----------
    device   : torch.device
                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
    """
    super(FVVDP, self).__init__()
    try:
        import pyfvvdp
        self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)
    except Exception as e:
        logging.warning('FovVideoVDP is missing, consider installing by running "pip install pyfvvdp"')
        logging.warning(e)

forward(predictions, targets, dim_order='BCHW')

Parameters:

  • predictions –
            The predicted images.
    
  • targets –
            The ground truth images.
    
  • dim_order –
            The dimension order of the input images. Defaults to 'BCHW' (batch, channels, height, width).
    

Returns:

  • result ( tensor ) –

    The computed loss if successful, otherwise 0.0.

Source code in odak/learn/perception/learned_perceptual_losses.py
def forward(self, predictions, targets, dim_order = 'BCHW'):
    """
    Parameters
    ----------
    predictions   : torch.tensor
                    The predicted images.
    targets       : torch.tensor
                    The ground truth images.
    dim_order     : str
                    The dimension order of the input images. Defaults to 'BCHW' (batch, channels, height, width).

    Returns
    -------
    result        : torch.tensor
                      The computed loss if successful, otherwise 0.0.
    """
    try:
        if len(predictions.shape) == 3:
            predictions = predictions.unsqueeze(0)
            targets = targets.unsqueeze(0)
        l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]
        return l_FovVideoVDP
    except Exception as e:
        logging.warning('FovVideoVDP failed to compute.')
        logging.warning(e)
        return torch.tensor(0.0)

LPIPS

Bases: Module

Source code in odak/learn/perception/learned_perceptual_losses.py
class LPIPS(nn.Module):

    def __init__(self):
        """
        Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.

        """
        super(LPIPS, self).__init__()
        try:
            import torchmetrics
            self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')
        except Exception as e:
            logging.warning('torchmetrics is missing, consider installing by running "pip install torchmetrics"')
            logging.warning(e)


    def forward(self, predictions, targets):
        """
        Parameters
        ----------
        predictions   : torch.tensor
                        The predicted images.
        targets       : torch.tensor
                        The ground truth images.

        Returns
        -------
        result        : torch.tensor
                        The computed loss if successful, otherwise 0.0.
        """
        try:
            if len(predictions.shape) == 3:
                predictions = predictions.unsqueeze(0)
                targets = targets.unsqueeze(0)
            lpips_image = predictions
            lpips_target = targets
            if len(lpips_image.shape) == 3:
                lpips_image = lpips_image.unsqueeze(0)
                lpips_target = lpips_target.unsqueeze(0)
            if lpips_image.shape[1] == 1:
                lpips_image = lpips_image.repeat(1, 3, 1, 1)
                lpips_target = lpips_target.repeat(1, 3, 1, 1)
            lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)
            lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)
            l_LPIPS = self.lpips(lpips_image, lpips_target)
            return l_LPIPS
        except Exception as e:
            logging.warning('LPIPS failed to compute.')
            logging.warning(e)
            return torch.tensor(0.0)
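
A minimal usage sketch, assuming torchmetrics is installed and inputs are in the [0, 1] range (the class rescales them to [-1, 1] internally and repeats single-channel inputs to three channels):

import torch
from odak.learn.perception import LPIPS

predictions = torch.rand(1, 3, 256, 256)
targets = torch.rand(1, 3, 256, 256)

loss_function = LPIPS()
loss = loss_function(predictions, targets)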

__init__()

Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.

Source code in odak/learn/perception/learned_perceptual_losses.py
def __init__(self):
    """
    Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.

    """
    super(LPIPS, self).__init__()
    try:
        import torchmetrics
        self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')
    except Exception as e:
        logging.warning('torchmetrics is missing, consider installing by running "pip install torchmetrics"')
        logging.warning(e)

forward(predictions, targets)

Parameters:

  • predictions –
            The predicted images.
    
  • targets –
            The ground truth images.
    

Returns:

  • result ( tensor ) –

    The computed loss if successful, otherwise 0.0.

Source code in odak/learn/perception/learned_perceptual_losses.py
def forward(self, predictions, targets):
    """
    Parameters
    ----------
    predictions   : torch.tensor
                    The predicted images.
    targets       : torch.tensor
                    The ground truth images.

    Returns
    -------
    result        : torch.tensor
                    The computed loss if successful, otherwise 0.0.
    """
    try:
        if len(predictions.shape) == 3:
            predictions = predictions.unsqueeze(0)
            targets = targets.unsqueeze(0)
        lpips_image = predictions
        lpips_target = targets
        if len(lpips_image.shape) == 3:
            lpips_image = lpips_image.unsqueeze(0)
            lpips_target = lpips_target.unsqueeze(0)
        if lpips_image.shape[1] == 1:
            lpips_image = lpips_image.repeat(1, 3, 1, 1)
            lpips_target = lpips_target.repeat(1, 3, 1, 1)
        lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)
        lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)
        l_LPIPS = self.lpips(lpips_image, lpips_target)
        return l_LPIPS
    except Exception as e:
        logging.warning('LPIPS failed to compute.')
        logging.warning(e)
        return torch.tensor(0.0)

MSSSIM

Bases: Module

A class to calculate the multi-scale structural similarity index (MS-SSIM) of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class MSSSIM(nn.Module):
    '''
    A class to calculate multi-scale structural similarity index of an image with respect to a ground truth image.
    '''

    def __init__(self):
        super(MSSSIM, self).__init__()

    def forward(self, predictions, targets):
        """
        Parameters
        ----------
        predictions : torch.tensor
                      The predicted images.
        targets     : torch.tensor
                      The ground truth images.

        Returns
        -------
        result      : torch.tensor 
                      The computed MS-SSIM value if successful, otherwise 0.0.
        """
        try:
            from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
            if len(predictions.shape) == 3:
                predictions = predictions.unsqueeze(0)
                targets = targets.unsqueeze(0)
            l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)
            return l_MSSSIM  
        except Exception as e:
            logging.warning('MS-SSIM failed to compute.')
            logging.warning(e)
            return torch.tensor(0.0)
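
A minimal usage sketch, assuming torchmetrics is installed and inputs are in the [0, 1] range. Note that forward returns the MS-SSIM value itself (higher means more similar), so a typical optimisation objective would minimise 1 - value:

import torch
from odak.learn.perception import MSSSIM

predictions = torch.rand(1, 3, 256, 256)
targets = torch.rand(1, 3, 256, 256)

msssim = MSSSIM()
value = msssim(predictions, targets)
loss = 1.0 - value # hypothetical loss term built from the similarity value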

forward(predictions, targets)

Parameters:

  • predictions (tensor) –
          The predicted images.
    
  • targets –
          The ground truth images.
    

Returns:

  • result ( tensor ) –

    The computed MS-SSIM value if successful, otherwise 0.0.

Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets):
    """
    Parameters
    ----------
    predictions : torch.tensor
                  The predicted images.
    targets     : torch.tensor
                  The ground truth images.

    Returns
    -------
    result      : torch.tensor 
                  The computed MS-SSIM value if successful, otherwise 0.0.
    """
    try:
        from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
        if len(predictions.shape) == 3:
            predictions = predictions.unsqueeze(0)
            targets = targets.unsqueeze(0)
        l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)
        return l_MSSSIM  
    except Exception as e:
        logging.warning('MS-SSIM failed to compute.')
        logging.warning(e)
        return torch.tensor(0.0)

MetamerMSELoss

The MetamerMSELoss class provides a perceptual loss function. It generates a metamer for the target image, then optimises the source image to match that target image metamer.

Please note this is different to MetamericLoss, which optimises the source image to be any metamer of the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.
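
A minimal usage sketch with placeholder 256x256 RGB tensors in NCHW format (values in [0, 1]); the keyword values shown are the defaults:

import torch
from odak.learn.perception import MetamerMSELoss

image = torch.rand(1, 3, 256, 256, requires_grad = True)
target = torch.rand(1, 3, 256, 256)

loss_function = MetamerMSELoss(alpha = 0.2, n_pyramid_levels = 5, n_orientations = 2)
loss = loss_function(image, target, gaze = [0.5, 0.5])
loss.backward()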

Source code in odak/learn/perception/metamer_mse_loss.py
class MetamerMSELoss():
    """ 
    The `MetamerMSELoss` class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.

    Please note this is different to `MetamericLoss` which optimises the source image to be any metamer of the target image.

    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
    """


    def __init__(self, device=torch.device("cpu"),
                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
                 n_pyramid_levels=5, n_orientations=2, equi=False):
        """
        Parameters
        ----------
        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
        n_pyramid_levels        : int 
                                    Number of levels of the steerable pyramid. Note that the image is padded
                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                    too high will slow down the calculation a lot.
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        n_orientations          : int 
                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                    Increasing this will increase runtime.
        equi                    : bool
                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.
                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi]
        """
        self.target = None
        self.target_metamer = None
        self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,
                                            real_viewing_distance=real_viewing_distance,
                                            n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)
        self.loss_func = torch.nn.MSELoss()
        self.noise = None

    def gen_metamer(self, image, gaze):
        """ 
        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
        This function can be used on its own to generate a metamer for a desired image.

        Parameters
        ----------
        image   : torch.tensor
                Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
        gaze    : list
                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

        Returns
        -------

        metamer : torch.tensor
                The generated metamer image
        """
        image = rgb_2_ycrcb(image)
        image_size = image.size()
        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)

        target_stats = self.metameric_loss.calc_statsmaps(
            image, gaze=gaze, alpha=self.metameric_loss.alpha)
        target_means = target_stats[::2]
        target_stdevs = target_stats[1::2]
        if self.noise is None or self.noise.size() != image.size():
            torch.manual_seed(0)
            noise_image = torch.rand_like(image)
        noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
            noise_image, self.metameric_loss.n_pyramid_levels)
        input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
            image, self.metameric_loss.n_pyramid_levels)

        def match_level(input_level, target_mean, target_std):
            level = input_level.clone()
            level -= torch.mean(level)
            input_std = torch.sqrt(torch.mean(level * level))
            eps = 1e-6
            # Safeguard against divide by zero
            input_std[input_std < eps] = eps
            level /= input_std
            level *= target_std
            level += target_mean
            return level

        nbands = len(noise_pyramid[0]["b"])
        noise_pyramid[0]["h"] = match_level(
            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
        for l in range(len(noise_pyramid)-1):
            for b in range(nbands):
                noise_pyramid[l]["b"][b] = match_level(
                    noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

        metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
            noise_pyramid)
        metamer = ycrcb_2_rgb(metamer)
        # Crop to remove any padding
        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
        return metamer

    def __call__(self, image, target, gaze=[0.5, 0.5]):
        """ 
        Calculates the Metamer MSE Loss.

        Parameters
        ----------
        image   : torch.tensor
                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target  : torch.tensor
                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        gaze    : list
                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("MetamerMSELoss", image, target)
        # Pad image and target if necessary
        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
        target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)

        if target is not self.target or self.target is None:
            self.target_metamer = self.gen_metamer(target, gaze)
            self.target = target

        return self.loss_func(image, self.target_metamer)

    def to(self, device):
        self.metameric_loss = self.metameric_loss.to(device)
        return self

__call__(image, target, gaze=[0.5, 0.5])

Calculates the Metamer MSE Loss.

Parameters:

  • image –
    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target –
    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • gaze –
    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/metamer_mse_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):
    """ 
    Calculates the Metamer MSE Loss.

    Parameters
    ----------
    image   : torch.tensor
            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target  : torch.tensor
            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    gaze    : list
            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("MetamerMSELoss", image, target)
    # Pad image and target if necessary
    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
    target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)

    if target is not self.target or self.target is None:
        self.target_metamer = self.gen_metamer(target, gaze)
        self.target = target

    return self.loss_func(image, self.target_metamer)

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', n_pyramid_levels=5, n_orientations=2, equi=False)

Parameters:

  • alpha –
                        parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width –
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
    
  • real_viewing_distance –
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
    
  • n_pyramid_levels –
                        Number of levels of the steerable pyramid. Note that the image is padded
                        so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                        too high will slow down the calculation a lot.
    
  • mode –
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • n_orientations –
                        Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                        Increasing this will increase runtime.
    
  • equi –
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                        The gaze argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi/2].
    
Source code in odak/learn/perception/metamer_mse_loss.py
def __init__(self, device=torch.device("cpu"),
             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
             n_pyramid_levels=5, n_orientations=2, equi=False):
    """
    Parameters
    ----------
    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
    n_pyramid_levels        : int 
                                Number of levels of the steerable pyramid. Note that the image is padded
                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                too high will slow down the calculation a lot.
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    n_orientations          : int 
                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                Increasing this will increase runtime.
    equi                    : bool
                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing distance are ignored.
                                The gaze argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi]
    """
    self.target = None
    self.target_metamer = None
    self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,
                                        real_viewing_distance=real_viewing_distance,
                                        n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)
    self.loss_func = torch.nn.MSELoss()
    self.noise = None

gen_metamer(image, gaze)

Generates a metamer for an image, following the method in this paper (https://dl.acm.org/doi/abs/10.1145/3450626.3459943). This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image –
    Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
    
  • gaze –
    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    

Returns:

  • metamer ( tensor ) –

    The generated metamer image
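
For standalone metamer generation, a minimal sketch using a placeholder RGB NCHW tensor in [0, 1]:

import torch
from odak.learn.perception import MetamerMSELoss

target = torch.rand(1, 3, 256, 256)
metamer = MetamerMSELoss().gen_metamer(target, gaze = [0.5, 0.5]) # returns a metamer with the same shape as the input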

Source code in odak/learn/perception/metamer_mse_loss.py
def gen_metamer(self, image, gaze):
    """ 
    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
    This function can be used on its own to generate a metamer for a desired image.

    Parameters
    ----------
    image   : torch.tensor
            Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
    gaze    : list
            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

    Returns
    -------

    metamer : torch.tensor
            The generated metamer image
    """
    image = rgb_2_ycrcb(image)
    image_size = image.size()
    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)

    target_stats = self.metameric_loss.calc_statsmaps(
        image, gaze=gaze, alpha=self.metameric_loss.alpha)
    target_means = target_stats[::2]
    target_stdevs = target_stats[1::2]
    if self.noise is None or self.noise.size() != image.size():
        torch.manual_seed(0)
        noise_image = torch.rand_like(image)
    noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
        noise_image, self.metameric_loss.n_pyramid_levels)
    input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
        image, self.metameric_loss.n_pyramid_levels)

    def match_level(input_level, target_mean, target_std):
        level = input_level.clone()
        level -= torch.mean(level)
        input_std = torch.sqrt(torch.mean(level * level))
        eps = 1e-6
        # Safeguard against divide by zero
        input_std[input_std < eps] = eps
        level /= input_std
        level *= target_std
        level += target_mean
        return level

    nbands = len(noise_pyramid[0]["b"])
    noise_pyramid[0]["h"] = match_level(
        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
    for l in range(len(noise_pyramid)-1):
        for b in range(nbands):
            noise_pyramid[l]["b"][b] = match_level(
                noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

    metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
        noise_pyramid)
    metamer = ycrcb_2_rgb(metamer)
    # Crop to remove any padding
    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
    return metamer

MetamericLoss

The MetamericLoss class provides a perceptual loss function.

Rather than exactly matching the source image to the target, it tries to ensure the source is a metamer of the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.
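
A minimal usage sketch with placeholder 256x256 RGB tensors in NCHW format (values in [0, 1]); the keyword values shown are the defaults:

import torch
from odak.learn.perception import MetamericLoss

image = torch.rand(1, 3, 256, 256, requires_grad = True)
target = torch.rand(1, 3, 256, 256)

loss_function = MetamericLoss(alpha = 0.2, n_pyramid_levels = 5, n_orientations = 2)
loss = loss_function(image, target, gaze = [0.5, 0.5])
loss.backward()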

Source code in odak/learn/perception/metameric_loss.py
class MetamericLoss():
    """
    The `MetamericLoss` class provides a perceptual loss function.

    Rather than exactly match the source image to the target, it tries to ensure the source is a *metamer* to the target image.

    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
    """


    def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
                 real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
                 n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
                 use_fullres_l0=False, equi=False):
        """
        Parameters
        ----------

        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
        n_pyramid_levels        : int 
                                    Number of levels of the steerable pyramid. Note that the image is padded
                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                    too high will slow down the calculation a lot.
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        n_orientations          : int 
                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                    Increasing this will increase runtime.
        use_l2_foveal_loss      : bool 
                                    If true, for all the pixels that have pooling size 1 pixel in the 
                                    largest scale will use direct L2 against target rather than pooling over pyramid levels.
                                    In practice this gives better results when the loss is used for holography.
        fovea_weight            : float 
                                    A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
        use_radial_weight       : bool 
                                    If True, will apply a radial weighting when calculating the difference between
                                    the source and target stats maps. This weights stats closer to the fovea more than those
                                    further away.
        use_fullres_l0          : bool 
                                    If true, stats for the lowpass residual are replaced with blurred versions
                                    of the full-resolution source and target images.
        equi                    : bool
                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.
                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi]
        """
        self.target = None
        self.device = device
        self.pyramid_maker = None
        self.alpha = alpha
        self.real_image_width = real_image_width
        self.real_viewing_distance = real_viewing_distance
        self.blurs = None
        self.n_pyramid_levels = n_pyramid_levels
        self.n_orientations = n_orientations
        self.mode = mode
        self.use_l2_foveal_loss = use_l2_foveal_loss
        self.fovea_weight = fovea_weight
        self.use_radial_weight = use_radial_weight
        self.use_fullres_l0 = use_fullres_l0
        self.equi = equi
        if self.use_fullres_l0 and self.use_l2_foveal_loss:
            raise Exception(
                "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")

    def calc_statsmaps(self, image, gaze=None, alpha=0.01, real_image_width=0.3,
                       real_viewing_distance=0.6, mode="quadratic", equi=False):

        if self.pyramid_maker is None or \
                self.pyramid_maker.device != self.device or \
                len(self.pyramid_maker.band_filters) != self.n_orientations or\
                self.pyramid_maker.filt_h0.size(0) != image.size(1):
            self.pyramid_maker = SpatialSteerablePyramid(
                use_bilinear_downup=False, n_channels=image.size(1),
                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)

        if self.blurs is None or len(self.blurs) != self.n_pyramid_levels:
            self.blurs = [RadiallyVaryingBlur()
                          for i in range(self.n_pyramid_levels)]

        def find_stats(image_pyr_level, blur):
            image_means = blur.blur(
                image_pyr_level, alpha, real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)
            image_meansq = blur.blur(image_pyr_level*image_pyr_level, alpha,
                                     real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)

            image_vars = image_meansq - (image_means*image_means)
            image_vars[image_vars < 1e-7] = 1e-7
            image_std = torch.sqrt(image_vars)
            if torch.any(torch.isnan(image_means)):
                print(image_means)
                raise Exception("NaN in image means!")
            if torch.any(torch.isnan(image_std)):
                print(image_std)
                raise Exception("NaN in image stdevs!")
            if self.use_fullres_l0:
                mask = blur.lod_map > 1e-6
                mask = mask[None, None, ...]
                if image_means.size(1) > 1:
                    mask = mask.repeat(1, image_means.size(1), 1, 1)
                matte = torch.zeros_like(image_means)
                matte[mask] = 1.0
                return image_means * matte, image_std * matte
            return image_means, image_std
        output_stats = []
        image_pyramid = self.pyramid_maker.construct_pyramid(
            image, self.n_pyramid_levels)
        means, variances = find_stats(image_pyramid[0]['h'], self.blurs[0])
        if self.use_l2_foveal_loss:
            self.fovea_mask = torch.zeros(image.size(), device=image.device)
            for i in range(self.fovea_mask.size(1)):
                self.fovea_mask[0, i, ...] = 1.0 - \
                    (self.blurs[0].lod_map / torch.max(self.blurs[0].lod_map))
                self.fovea_mask[0, i, self.blurs[0].lod_map < 1e-6] = 1.0
            self.fovea_mask = torch.pow(self.fovea_mask, 10.0)
            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, scale_factor=0.125, mode="area")
            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, size=(image.size(-2), image.size(-1)), mode="bilinear")
            periphery_mask = 1.0 - self.fovea_mask
            self.periphery_mask = periphery_mask.clone()
            output_stats.append(means * periphery_mask)
            output_stats.append(variances * periphery_mask)
        else:
            output_stats.append(means)
            output_stats.append(variances)

        for l in range(0, len(image_pyramid)-1):
            for o in range(len(image_pyramid[l]['b'])):
                means, variances = find_stats(
                    image_pyramid[l]['b'][o], self.blurs[l])
                if self.use_l2_foveal_loss:
                    output_stats.append(means * periphery_mask)
                    output_stats.append(variances * periphery_mask)
                else:
                    output_stats.append(means)
                    output_stats.append(variances)
            if self.use_l2_foveal_loss:
                periphery_mask = torch.nn.functional.interpolate(
                    periphery_mask, scale_factor=0.5, mode="area", recompute_scale_factor=False)

        if self.use_l2_foveal_loss:
            output_stats.append(image_pyramid[-1]["l"] * periphery_mask)
        elif self.use_fullres_l0:
            output_stats.append(self.blurs[0].blur(
                image, alpha, real_image_width, real_viewing_distance, gaze, mode))
        else:
            output_stats.append(image_pyramid[-1]["l"])
        return output_stats

    def metameric_loss_stats(self, statsmap_a, statsmap_b, gaze):
        loss = 0.0
        for a, b in zip(statsmap_a, statsmap_b):
            if self.use_radial_weight:
                radii = make_radial_map(
                    [a.size(-2), a.size(-1)], gaze).to(a.device)
                weights = 1.1 - (radii * radii * radii * radii)
                weights = weights[None, None, ...].repeat(1, a.size(1), 1, 1)
                loss += torch.nn.MSELoss()(weights*a, weights*b)
            else:
                loss += torch.nn.MSELoss()(a, b)
        loss /= len(statsmap_a)
        return loss

    def visualise_loss_map(self, image_stats):
        loss_map = torch.zeros(image_stats[0].size()[-2:])
        for i in range(len(image_stats)):
            stats = image_stats[i]
            target_stats = self.target_stats[i]
            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))
            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(
            ), mode="bilinear", align_corners=False, recompute_scale_factor=False)
            loss_map += stat_mse_map[0, 0, ...]
        self.loss_map = loss_map

    def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
        """ 
        Calculates the Metameric Loss.

        Parameters
        ----------
        image               : torch.tensor
                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target              : torch.tensor
                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        image_colorspace    : str
                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
                                accepted values: RGB, YCrCb.
        gaze                : list
                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
        visualise_loss      : bool
                                Shows a heatmap indicating which parts of the image contributed most to the loss. 

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("MetamericLoss", image, target)
        # Pad image and target if necessary
        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
        # If input is RGB, convert to YCrCb.
        if image.size(1) == 3 and image_colorspace == "RGB":
            image = rgb_2_ycrcb(image)
            target = rgb_2_ycrcb(target)
        if self.target is None:
            self.target = torch.zeros(target.shape).to(target.device)
        if type(target) == type(self.target):
            if not torch.all(torch.eq(target, self.target)):
                self.target = target.detach().clone()
                self.target_stats = self.calc_statsmaps(
                    self.target,
                    gaze=gaze,
                    alpha=self.alpha,
                    real_image_width=self.real_image_width,
                    real_viewing_distance=self.real_viewing_distance,
                    mode=self.mode
                )
                self.target = target.detach().clone()
            image_stats = self.calc_statsmaps(
                image,
                gaze=gaze,
                alpha=self.alpha,
                real_image_width=self.real_image_width,
                real_viewing_distance=self.real_viewing_distance,
                mode=self.mode
            )
            if visualise_loss:
                self.visualise_loss_map(image_stats)
            if self.use_l2_foveal_loss:
                peripheral_loss = self.metameric_loss_stats(
                    image_stats, self.target_stats, gaze)
                foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)
                # New weighting - evenly weight fovea and periphery.
                loss = peripheral_loss + self.fovea_weight * foveal_loss
            else:
                loss = self.metameric_loss_stats(
                    image_stats, self.target_stats, gaze)
            return loss
        else:
            raise Exception("Target of incorrect type")

    def to(self, device):
        self.device = device
        return self

__call__(image, target, gaze=[0.5, 0.5], image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image –
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target –
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • image_colorspace –
                    The current colorspace of your image and target. Ignored if input does not have 3 channels.
                    accepted values: RGB, YCrCb.
    
  • gaze –
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    
  • visualise_loss –
                    Shows a heatmap indicating which parts of the image contributed most to the loss.
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/metameric_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
    """ 
    Calculates the Metameric Loss.

    Parameters
    ----------
    image               : torch.tensor
                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target              : torch.tensor
                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    image_colorspace    : str
                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
                            accepted values: RGB, YCrCb.
    gaze                : list
                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    visualise_loss      : bool
                            Shows a heatmap indicating which parts of the image contributed most to the loss. 

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("MetamericLoss", image, target)
    # Pad image and target if necessary
    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
    # If input is RGB, convert to YCrCb.
    if image.size(1) == 3 and image_colorspace == "RGB":
        image = rgb_2_ycrcb(image)
        target = rgb_2_ycrcb(target)
    if self.target is None:
        self.target = torch.zeros(target.shape).to(target.device)
    if type(target) == type(self.target):
        if not torch.all(torch.eq(target, self.target)):
            self.target = target.detach().clone()
            self.target_stats = self.calc_statsmaps(
                self.target,
                gaze=gaze,
                alpha=self.alpha,
                real_image_width=self.real_image_width,
                real_viewing_distance=self.real_viewing_distance,
                mode=self.mode
            )
            self.target = target.detach().clone()
        image_stats = self.calc_statsmaps(
            image,
            gaze=gaze,
            alpha=self.alpha,
            real_image_width=self.real_image_width,
            real_viewing_distance=self.real_viewing_distance,
            mode=self.mode
        )
        if visualise_loss:
            self.visualise_loss_map(image_stats)
        if self.use_l2_foveal_loss:
            peripheral_loss = self.metameric_loss_stats(
                image_stats, self.target_stats, gaze)
            foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)
            # New weighting - evenly weight fovea and periphery.
            loss = peripheral_loss + self.fovea_weight * foveal_loss
        else:
            loss = self.metameric_loss_stats(
                image_stats, self.target_stats, gaze)
        return loss
    else:
        raise Exception("Target of incorrect type")
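
Example (an illustrative sketch, not part of the library source): computing the metameric loss for a known gaze location. The import path follows this page's module, and the tensor sizes and gaze value are assumptions.

import torch
from odak.learn.perception import MetamericLoss

# Illustrative NCHW RGB tensors; 256x256 keeps both sides multiples of 2^n_pyramid_levels.
image = torch.rand(1, 3, 256, 256, requires_grad = True)
target = torch.rand(1, 3, 256, 256)

loss_function = MetamericLoss()
# Gaze is given in normalised image coordinates, measured from the top left.
loss = loss_function(image, target, gaze = [0.4, 0.6])
loss.backward()

Because the loss is differentiable with respect to image, it can drive a standard optimiser that updates the image directly.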

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, n_pyramid_levels=5, mode='quadratic', n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False, use_fullres_l0=False, equi=False)

Parameters:

  • alpha –
                        parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width –
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
    
  • real_viewing_distance –
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
    
  • n_pyramid_levels –
                        Number of levels of the steerable pyramid. Note that the image is padded
                        so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                        too high will slow down the calculation a lot.
    
  • mode –
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • n_orientations –
                        Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                        Increasing this will increase runtime.
    
  • use_l2_foveal_loss –
                        If true, pixels whose pooling size is 1 pixel at the largest scale
                        use a direct L2 loss against the target rather than pooling over pyramid levels.
                        In practice this gives better results when the loss is used for holography.
    
  • fovea_weight –
                        A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
    
  • use_radial_weight –
                        If True, will apply a radial weighting when calculating the difference between
                        the source and target stats maps. This weights stats closer to the fovea more than those
                        further away.
    
  • use_fullres_l0 –
                        If true, stats for the lowpass residual are replaced with blurred versions
                        of the full-resolution source and target images.
    
  • equi –
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                        The gaze argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi/2]
    
Source code in odak/learn/perception/metameric_loss.py
def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
             real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
             n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
             use_fullres_l0=False, equi=False):
    """
    Parameters
    ----------

    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
    n_pyramid_levels        : int 
                                Number of levels of the steerable pyramid. Note that the image is padded
                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                too high will slow down the calculation a lot.
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    n_orientations          : int 
                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                Increasing this will increase runtime.
    use_l2_foveal_loss      : bool 
                                If true, pixels whose pooling size is 1 pixel at the largest scale
                                use a direct L2 loss against the target rather than pooling over pyramid levels.
                                In practice this gives better results when the loss is used for holography.
    fovea_weight            : float 
                                A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
    use_radial_weight       : bool 
                                If True, will apply a radial weighting when calculating the difference between
                                the source and target stats maps. This weights stats closer to the fovea more than those
                                further away.
    use_fullres_l0          : bool 
                                If true, stats for the lowpass residual are replaced with blurred versions
                                of the full-resolution source and target images.
    equi                    : bool
                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                The gaze argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi/2]
    """
    self.target = None
    self.device = device
    self.pyramid_maker = None
    self.alpha = alpha
    self.real_image_width = real_image_width
    self.real_viewing_distance = real_viewing_distance
    self.blurs = None
    self.n_pyramid_levels = n_pyramid_levels
    self.n_orientations = n_orientations
    self.mode = mode
    self.use_l2_foveal_loss = use_l2_foveal_loss
    self.fovea_weight = fovea_weight
    self.use_radial_weight = use_radial_weight
    self.use_fullres_l0 = use_fullres_l0
    self.equi = equi
    if self.use_fullres_l0 and self.use_l2_foveal_loss:
        raise Exception(
            "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")

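
A further sketch, assuming equirectangular (360) input: with equi=True the gaze argument is interpreted as yaw and pitch angles in radians rather than normalised image coordinates. The resolution below is an assumption.

import torch
from odak.learn.perception import MetamericLoss

pano_image = torch.rand(1, 3, 256, 512)
pano_target = torch.rand(1, 3, 256, 512)

loss_function = MetamericLoss(equi = True)
# Gaze angles (yaw, pitch) in radians; (0, 0) looks at the centre of the panorama.
loss = loss_function(pano_image, pano_target, gaze = [0.0, 0.0])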
MetamericLossUniform

Measures metameric loss between a given image and a metamer of the given target image. This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.

Source code in odak/learn/perception/metameric_loss_uniform.py
class MetamericLossUniform():
    """
    Measures metameric loss between a given image and a metamer of the given target image.
    This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.
    """

    def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
        """

        Parameters
        ----------
        pooling_size            : int
                                  Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
        n_pyramid_levels        : int 
                                  Number of levels of the steerable pyramid. Note that the image is padded
                                  so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                  too high will slow down the calculation a lot.
        n_orientations          : int 
                                  Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                  Increasing this will increase runtime.

        """
        self.target = None
        self.device = device
        self.pyramid_maker = None
        self.pooling_size = pooling_size
        self.n_pyramid_levels = n_pyramid_levels
        self.n_orientations = n_orientations

    def calc_statsmaps(self, image, pooling_size):

        if self.pyramid_maker is None or \
                self.pyramid_maker.device != self.device or \
                len(self.pyramid_maker.band_filters) != self.n_orientations or\
                self.pyramid_maker.filt_h0.size(0) != image.size(1):
            self.pyramid_maker = SpatialSteerablePyramid(
                use_bilinear_downup=False, n_channels=image.size(1),
                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)


        def find_stats(image_pyr_level, pooling_size):
            image_means = uniform_blur(image_pyr_level, pooling_size)
            image_meansq = uniform_blur(image_pyr_level*image_pyr_level, pooling_size)
            image_vars = image_meansq - (image_means*image_means)
            image_vars[image_vars < 1e-7] = 1e-7
            image_std = torch.sqrt(image_vars)
            if torch.any(torch.isnan(image_means)):
                print(image_means)
                raise Exception("NaN in image means!")
            if torch.any(torch.isnan(image_std)):
                print(image_std)
                raise Exception("NaN in image stdevs!")
            return image_means, image_std

        output_stats = []
        image_pyramid = self.pyramid_maker.construct_pyramid(
            image, self.n_pyramid_levels)
        curr_pooling_size = pooling_size
        means, variances = find_stats(image_pyramid[0]['h'], curr_pooling_size)
        output_stats.append(means)
        output_stats.append(variances)

        for l in range(0, len(image_pyramid)-1):
            for o in range(len(image_pyramid[l]['b'])):
                means, variances = find_stats(
                    image_pyramid[l]['b'][o], curr_pooling_size)
                output_stats.append(means)
                output_stats.append(variances)
            curr_pooling_size /= 2

        output_stats.append(image_pyramid[-1]["l"])
        return output_stats

    def metameric_loss_stats(self, statsmap_a, statsmap_b):
        loss = 0.0
        for a, b in zip(statsmap_a, statsmap_b):
            loss += torch.nn.MSELoss()(a, b)
        loss /= len(statsmap_a)
        return loss

    def visualise_loss_map(self, image_stats):
        loss_map = torch.zeros(image_stats[0].size()[-2:])
        for i in range(len(image_stats)):
            stats = image_stats[i]
            target_stats = self.target_stats[i]
            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))
            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(
            ), mode="bilinear", align_corners=False, recompute_scale_factor=False)
            loss_map += stat_mse_map[0, 0, ...]
        self.loss_map = loss_map

    def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
        """ 
        Calculates the Metameric Loss.

        Parameters
        ----------
        image               : torch.tensor
                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target              : torch.tensor
                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        image_colorspace    : str
                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
                                accepted values: RGB, YCrCb.
        visualise_loss      : bool
                                Shows a heatmap indicating which parts of the image contributed most to the loss. 

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("MetamericLossUniform", image, target)
        # Pad image and target if necessary
        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
        # If input is RGB, convert to YCrCb.
        if image.size(1) == 3 and image_colorspace == "RGB":
            image = rgb_2_ycrcb(image)
            target = rgb_2_ycrcb(target)
        if self.target is None:
            self.target = torch.zeros(target.shape).to(target.device)
        if type(target) == type(self.target):
            if not torch.all(torch.eq(target, self.target)):
                self.target = target.detach().clone()
                self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)
                self.target = target.detach().clone()
            image_stats = self.calc_statsmaps(image, self.pooling_size)

            if visualise_loss:
                self.visualise_loss_map(image_stats)
            loss = self.metameric_loss_stats(
                image_stats, self.target_stats)
            return loss
        else:
            raise Exception("Target of incorrect type")

    def gen_metamer(self, image):
        """ 
        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
        This function can be used on its own to generate a metamer for a desired image.

        Parameters
        ----------
        image   : torch.tensor
                  Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)

        Returns
        -------
        metamer : torch.tensor
                  The generated metamer image
        """
        image = rgb_2_ycrcb(image)
        image_size = image.size()
        image = pad_image_for_pyramid(image, self.n_pyramid_levels)

        target_stats = self.calc_statsmaps(
            image, self.pooling_size)
        target_means = target_stats[::2]
        target_stdevs = target_stats[1::2]
        torch.manual_seed(0)
        noise_image = torch.rand_like(image)
        noise_pyramid = self.pyramid_maker.construct_pyramid(
            noise_image, self.n_pyramid_levels)
        input_pyramid = self.pyramid_maker.construct_pyramid(
            image, self.n_pyramid_levels)

        def match_level(input_level, target_mean, target_std):
            level = input_level.clone()
            level -= torch.mean(level)
            input_std = torch.sqrt(torch.mean(level * level))
            eps = 1e-6
            # Safeguard against divide by zero
            input_std[input_std < eps] = eps
            level /= input_std
            level *= target_std
            level += target_mean
            return level

        nbands = len(noise_pyramid[0]["b"])
        noise_pyramid[0]["h"] = match_level(
            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
        for l in range(len(noise_pyramid)-1):
            for b in range(nbands):
                noise_pyramid[l]["b"][b] = match_level(
                    noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

        metamer = self.pyramid_maker.reconstruct_from_pyramid(
            noise_pyramid)
        metamer = ycrcb_2_rgb(metamer)
        # Crop to remove any padding
        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
        return metamer

    def to(self, device):
        self.device = device
        return self

__call__(image, target, image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image –
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target –
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • image_colorspace –
                    The current colorspace of your image and target. Ignored if input does not have 3 channels.
                    accepted values: RGB, YCrCb.
    
  • visualise_loss –
                    Shows a heatmap indicating which parts of the image contributed most to the loss.
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/metameric_loss_uniform.py
def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
    """ 
    Calculates the Metameric Loss.

    Parameters
    ----------
    image               : torch.tensor
                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target              : torch.tensor
                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    image_colorspace    : str
                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
                            accepted values: RGB, YCrCb.
    visualise_loss      : bool
                            Shows a heatmap indicating which parts of the image contributed most to the loss. 

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("MetamericLossUniform", image, target)
    # Pad image and target if necessary
    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
    # If input is RGB, convert to YCrCb.
    if image.size(1) == 3 and image_colorspace == "RGB":
        image = rgb_2_ycrcb(image)
        target = rgb_2_ycrcb(target)
    if self.target is None:
        self.target = torch.zeros(target.shape).to(target.device)
    if type(target) == type(self.target):
        if not torch.all(torch.eq(target, self.target)):
            self.target = target.detach().clone()
            self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)
            self.target = target.detach().clone()
        image_stats = self.calc_statsmaps(image, self.pooling_size)

        if visualise_loss:
            self.visualise_loss_map(image_stats)
        loss = self.metameric_loss_stats(
            image_stats, self.target_stats)
        return loss
    else:
        raise Exception("Target of incorrect type")

__init__(device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2)

Parameters:

  • pooling_size –
                      Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
    
  • n_pyramid_levels –
                      Number of levels of the steerable pyramid. Note that the image is padded
                      so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                      too high will slow down the calculation a lot.
    
  • n_orientations –
                      Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                      Increasing this will increase runtime.
    
Source code in odak/learn/perception/metameric_loss_uniform.py
def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
    """

    Parameters
    ----------
    pooling_size            : int
                              Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
    n_pyramid_levels        : int 
                              Number of levels of the steerable pyramid. Note that the image is padded
                              so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                              too high will slow down the calculation a lot.
    n_orientations          : int 
                              Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                              Increasing this will increase runtime.

    """
    self.target = None
    self.device = device
    self.pyramid_maker = None
    self.pooling_size = pooling_size
    self.n_pyramid_levels = n_pyramid_levels
    self.n_orientations = n_orientations
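
Example (an illustrative sketch, not part of the library source): the uniform variant needs no gaze argument. The pooling_size and tensor sizes below are assumptions.

import torch
from odak.learn.perception import MetamericLossUniform

image = torch.rand(1, 3, 256, 256, requires_grad = True)
target = torch.rand(1, 3, 256, 256)

# Pool statistics over uniform 32x32 blocks across the whole image.
loss_function = MetamericLossUniform(pooling_size = 32)
loss = loss_function(image, target)
loss.backward()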

gen_metamer(image)

Generates a metamer for an image, following the method in this paper (https://dl.acm.org/doi/abs/10.1145/3450626.3459943). This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image –
      Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
    

Returns:

  • metamer ( tensor ) –

    The generated metamer image

Source code in odak/learn/perception/metameric_loss_uniform.py
def gen_metamer(self, image):
    """ 
    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
    This function can be used on its own to generate a metamer for a desired image.

    Parameters
    ----------
    image   : torch.tensor
              Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)

    Returns
    -------
    metamer : torch.tensor
              The generated metamer image
    """
    image = rgb_2_ycrcb(image)
    image_size = image.size()
    image = pad_image_for_pyramid(image, self.n_pyramid_levels)

    target_stats = self.calc_statsmaps(
        image, self.pooling_size)
    target_means = target_stats[::2]
    target_stdevs = target_stats[1::2]
    torch.manual_seed(0)
    noise_image = torch.rand_like(image)
    noise_pyramid = self.pyramid_maker.construct_pyramid(
        noise_image, self.n_pyramid_levels)
    input_pyramid = self.pyramid_maker.construct_pyramid(
        image, self.n_pyramid_levels)

    def match_level(input_level, target_mean, target_std):
        level = input_level.clone()
        level -= torch.mean(level)
        input_std = torch.sqrt(torch.mean(level * level))
        eps = 1e-6
        # Safeguard against divide by zero
        input_std[input_std < eps] = eps
        level /= input_std
        level *= target_std
        level += target_mean
        return level

    nbands = len(noise_pyramid[0]["b"])
    noise_pyramid[0]["h"] = match_level(
        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
    for l in range(len(noise_pyramid)-1):
        for b in range(nbands):
            noise_pyramid[l]["b"][b] = match_level(
                noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

    metamer = self.pyramid_maker.reconstruct_from_pyramid(
        noise_pyramid)
    metamer = ycrcb_2_rgb(metamer)
    # Crop to remove any padding
    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
    return metamer
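
Example (an illustrative sketch): generating a metamer directly from an RGB input image, without computing a loss. The pooling size and resolution are assumptions.

import torch
from odak.learn.perception import MetamericLossUniform

image = torch.rand(1, 3, 256, 256)
loss_function = MetamericLossUniform(pooling_size = 32)
metamer = loss_function.gen_metamer(image)
# The metamer has the same NCHW shape as the input.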

PSNR

Bases: Module

A class to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class PSNR(nn.Module):
    '''
    A class to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.
    '''

    def __init__(self):
        super(PSNR, self).__init__()

    def forward(self, predictions, targets, peak_value = 1.0):
        """
        A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

        Parameters
        ----------
        predictions   : torch.tensor
                        Image to be tested.
        targets       : torch.tensor
                        Ground truth image.
        peak_value    : float
                        Peak value that given tensors could have.

        Returns
        -------
        result        : torch.tensor
                        Peak-signal-to-noise ratio.
        """
        mse = torch.mean((targets - predictions) ** 2)
        result = 20 * torch.log10(peak_value / torch.sqrt(mse))
        return result

forward(predictions, targets, peak_value=1.0)

A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

Parameters:

  • predictions –
            Image to be tested.
    
  • targets –
            Ground truth image.
    
  • peak_value –
            Peak value that given tensors could have.
    

Returns:

  • result ( tensor ) –

    Peak-signal-to-noise ratio.

Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets, peak_value = 1.0):
    """
    A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

    Parameters
    ----------
    predictions   : torch.tensor
                    Image to be tested.
    targets       : torch.tensor
                    Ground truth image.
    peak_value    : float
                    Peak value that given tensors could have.

    Returns
    -------
    result        : torch.tensor
                    Peak-signal-to-noise ratio.
    """
    mse = torch.mean((targets - predictions) ** 2)
    result = 20 * torch.log10(peak_value / torch.sqrt(mse))
    return result
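
The returned value follows the usual definition, PSNR = 20 * log10(peak_value / sqrt(MSE)). A minimal usage sketch, assuming images in the [0, 1] range:

import torch
from odak.learn.perception import PSNR

predictions = torch.rand(1, 3, 256, 256)
targets = torch.rand(1, 3, 256, 256)

psnr = PSNR()
# peak_value should match the dynamic range of the inputs (1.0 for [0, 1] images).
result = psnr(predictions, targets, peak_value = 1.0)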

RadiallyVaryingBlur

The RadiallyVaryingBlur class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given alpha parameter value. For more information on how the pooling sizes are computed, please see link coming soon.

The blur is accelerated by generating and sampling from MIP maps of the input image.

This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.

If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.

Source code in odak/learn/perception/radially_varying_blur.py
class RadiallyVaryingBlur():
    """ 

    The `RadiallyVaryingBlur` class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given `alpha` parameter value. For more information on how the pooling sizes are computed, please see [link coming soon]().

    The blur is accelerated by generating and sampling from MIP maps of the input image.

    This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.

    If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.

    """

    def __init__(self):
        self.lod_map = None
        self.equi = None

    def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
        """
        Apply the radially varying blur to an image.

        Parameters
        ----------

        image                   : torch.tensor
                                    The image to blur, in NCHW format.
        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
                                    Ignored in equirectangular mode (equi==True)
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
                                    Ignored in equirectangular mode (equi==True)
        centre                  : tuple of floats
                                    The centre of the radially varying blur (the gaze location).
                                    Should be a tuple of floats containing normalised image coordinates in range [0,1]
                                    In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        equi                    : bool
                                    If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                    The centre argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi/2]

        Returns
        -------

        output                  : torch.tensor
                                    The blurred image
        """
        size = (image.size(-2), image.size(-1))

        # LOD map caching
        if self.lod_map is None or\
                self.size != size or\
                self.n_channels != image.size(1) or\
                self.alpha != alpha or\
                self.real_image_width != real_image_width or\
                self.real_viewing_distance != real_viewing_distance or\
                self.centre != centre or\
                self.mode != mode or\
                self.equi != equi:
            if not equi:
                self.lod_map = make_pooling_size_map_lod(
                    centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
            else:
                self.lod_map = make_equi_pooling_size_map_lod(
                    centre, (image.size(-2), image.size(-1)), alpha, mode)
            self.size = size
            self.n_channels = image.size(1)
            self.alpha = alpha
            self.real_image_width = real_image_width
            self.real_viewing_distance = real_viewing_distance
            self.centre = centre
            self.lod_map = self.lod_map.to(image.device)
            self.lod_fraction = torch.fmod(self.lod_map, 1.0)
            self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
                1, image.size(1), 1, 1)
            self.mode = mode
            self.equi = equi

        if self.lod_map.device != image.device:
            self.lod_map = self.lod_map.to(image.device)
        if self.lod_fraction.device != image.device:
            self.lod_fraction = self.lod_fraction.to(image.device)

        mipmap = [image]
        while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:
            mipmap.append(torch.nn.functional.interpolate(
                mipmap[-1], scale_factor=0.5, mode="area", recompute_scale_factor=False))
        if mipmap[-1].size(-1) == 2:
            final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]
            mipmap.append(final_mip)
        if mipmap[-1].size(-2) == 2:
            final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]
            mipmap.append(final_mip)

        for l in range(len(mipmap)):
            if l == len(mipmap)-1:
                mipmap[l] = mipmap[l] * \
                    torch.ones(image.size(), device=image.device)
            else:
                for l2 in range(l-1, -1, -1):
                    mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
                        image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)

        output = torch.zeros(image.size(), device=image.device)
        for l in range(len(mipmap)):
            if l == 0:
                mask = self.lod_map < (l+1)
            elif l == len(mipmap)-1:
                mask = self.lod_map >= l
            else:
                mask = torch.logical_and(
                    self.lod_map >= l, self.lod_map < (l+1))

            if l == len(mipmap)-1:
                blended_levels = mipmap[l]
            else:
                blended_levels = (1 - self.lod_fraction) * \
                    mipmap[l] + self.lod_fraction*mipmap[l+1]
            mask = mask[None, None, ...]
            mask = mask.repeat(1, image.size(1), 1, 1)
            output[mask] = blended_levels[mask]

        return output

blur(image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode='quadratic', equi=False)

Apply the radially varying blur to an image.

Parameters:

  • image –
                        The image to blur, in NCHW format.
    
  • alpha –
                        parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width –
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
                        Ignored in equirectangular mode (equi==True)
    
  • real_viewing_distance –
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
                        Ignored in equirectangular mode (equi==True)
    
  • centre –
                        The centre of the radially varying blur (the gaze location).
                        Should be a tuple of floats containing normalised image coordinates in range [0,1]
                        In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
    
  • mode –
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • equi –
                        If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                        The centre argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi/2]
    

Returns:

  • output ( tensor ) –

    The blurred image

Source code in odak/learn/perception/radially_varying_blur.py
def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
    """
    Apply the radially varying blur to an image.

    Parameters
    ----------

    image                   : torch.tensor
                                The image to blur, in NCHW format.
    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
                                Ignored in equirectangular mode (equi==True)
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
                                Ignored in equirectangular mode (equi==True)
    centre                  : tuple of floats
                                The centre of the radially varying blur (the gaze location).
                                Should be a tuple of floats containing normalised image coordinates in range [0,1]
                                In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    equi                    : bool
                                If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                The centre argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi/2]

    Returns
    -------

    output                  : torch.tensor
                                The blurred image
    """
    size = (image.size(-2), image.size(-1))

    # LOD map caching
    if self.lod_map is None or\
            self.size != size or\
            self.n_channels != image.size(1) or\
            self.alpha != alpha or\
            self.real_image_width != real_image_width or\
            self.real_viewing_distance != real_viewing_distance or\
            self.centre != centre or\
            self.mode != mode or\
            self.equi != equi:
        if not equi:
            self.lod_map = make_pooling_size_map_lod(
                centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
        else:
            self.lod_map = make_equi_pooling_size_map_lod(
                centre, (image.size(-2), image.size(-1)), alpha, mode)
        self.size = size
        self.n_channels = image.size(1)
        self.alpha = alpha
        self.real_image_width = real_image_width
        self.real_viewing_distance = real_viewing_distance
        self.centre = centre
        self.lod_map = self.lod_map.to(image.device)
        self.lod_fraction = torch.fmod(self.lod_map, 1.0)
        self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
            1, image.size(1), 1, 1)
        self.mode = mode
        self.equi = equi

    if self.lod_map.device != image.device:
        self.lod_map = self.lod_map.to(image.device)
    if self.lod_fraction.device != image.device:
        self.lod_fraction = self.lod_fraction.to(image.device)

    mipmap = [image]
    while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:
        mipmap.append(torch.nn.functional.interpolate(
            mipmap[-1], scale_factor=0.5, mode="area", recompute_scale_factor=False))
    if mipmap[-1].size(-1) == 2:
        final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]
        mipmap.append(final_mip)
    if mipmap[-1].size(-2) == 2:
        final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]
        mipmap.append(final_mip)

    for l in range(len(mipmap)):
        if l == len(mipmap)-1:
            mipmap[l] = mipmap[l] * \
                torch.ones(image.size(), device=image.device)
        else:
            for l2 in range(l-1, -1, -1):
                mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
                    image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)

    output = torch.zeros(image.size(), device=image.device)
    for l in range(len(mipmap)):
        if l == 0:
            mask = self.lod_map < (l+1)
        elif l == len(mipmap)-1:
            mask = self.lod_map >= l
        else:
            mask = torch.logical_and(
                self.lod_map >= l, self.lod_map < (l+1))

        if l == len(mipmap)-1:
            blended_levels = mipmap[l]
        else:
            blended_levels = (1 - self.lod_fraction) * \
                mipmap[l] + self.lod_fraction*mipmap[l+1]
        mask = mask[None, None, ...]
        mask = mask.repeat(1, image.size(1), 1, 1)
        output[mask] = blended_levels[mask]

    return output
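
Example (an illustrative sketch, not part of the library source): blurring an image with the gaze at its centre. The parameter values mirror the documented defaults and the resolution is an assumption.

import torch
from odak.learn.perception import RadiallyVaryingBlur

image = torch.rand(1, 3, 256, 256)
blur = RadiallyVaryingBlur()
# Centre (gaze) is given in normalised image coordinates from the top left.
output = blur.blur(image, alpha = 0.2, real_image_width = 0.2,
                   real_viewing_distance = 0.7, centre = (0.5, 0.5), mode = 'quadratic')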

SSIM

Bases: Module

A class to calculate structural similarity index of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class SSIM(nn.Module):
    '''
    A class to calculate structural similarity index of an image with respect to a ground truth image.
    '''

    def __init__(self):
        super(SSIM, self).__init__()

    def forward(self, predictions, targets):
        """
        Parameters
        ----------
        predictions : torch.tensor
                      The predicted images.
        targets     : torch.tensor
                      The ground truth images.

        Returns
        -------
        result      : torch.tensor 
                      The computed SSIM value if successful, otherwise 0.0.
        """
        try:
            from torchmetrics.functional.image import structural_similarity_index_measure
            if len(predictions.shape) == 3:
                predictions = predictions.unsqueeze(0)
                targets = targets.unsqueeze(0)
            l_SSIM = structural_similarity_index_measure(predictions, targets)
            return l_SSIM
        except Exception as e:
            logging.warning('SSIM failed to compute.')
            logging.warning(e)
            return torch.tensor(0.0)

forward(predictions, targets)

Parameters:

  • predictions (tensor) –
          The predicted images.
    
  • targets –
          The ground truth images.
    

Returns:

  • result ( tensor ) –

    The computed SSIM value if successful, otherwise 0.0.

Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets):
    """
    Parameters
    ----------
    predictions : torch.tensor
                  The predicted images.
    targets     : torch.tensor
                  The ground truth images.

    Returns
    -------
    result      : torch.tensor 
                  The computed SSIM value if successful, otherwise 0.0.
    """
    try:
        from torchmetrics.functional.image import structural_similarity_index_measure
        if len(predictions.shape) == 3:
            predictions = predictions.unsqueeze(0)
            targets = targets.unsqueeze(0)
        l_SSIM = structural_similarity_index_measure(predictions, targets)
        return l_SSIM
    except Exception as e:
        logging.warning('SSIM failed to compute.')
        logging.warning(e)
        return torch.tensor(0.0)
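
A minimal usage sketch (torchmetrics must be installed for a non-zero result); the tensor sizes are assumptions.

import torch
from odak.learn.perception import SSIM

predictions = torch.rand(1, 3, 256, 256)
targets = torch.rand(1, 3, 256, 256)

ssim = SSIM()
result = ssim(predictions, targets)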

SpatialSteerablePyramid

This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution) as opposed to multiplication in the Fourier domain. This has a number of optimisations over previous implementations that increase efficiency, but introduce some reconstruction error.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
class SpatialSteerablePyramid():
    """
    This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution)
    as opposed to multiplication in the Fourier domain.
    This has a number of optimisations over previous implementations that increase efficiency, but introduce some
    reconstruction error.
    """


    def __init__(self, use_bilinear_downup=True, n_channels=1,
                 filter_size=9, n_orientations=6, filter_type="full",
                 device=torch.device('cpu')):
        """
        Parameters
        ----------

        use_bilinear_downup     : bool
                                    This uses bilinear filtering when upsampling/downsampling, rather than the original approach
                                    of applying a large lowpass kernel and sampling even rows/columns
        n_channels              : int
                                    Number of channels in the input images (e.g. 3 for RGB input)
        filter_size             : int
                                    Desired size of filters (e.g. 3 will use 3x3 filters).
        n_orientations          : int
                                    Number of oriented bands in each level of the pyramid.
        filter_type             : str
                                    This can be used to select smaller filters than the original ones if desired.
                                    full: Original filter sizes
                                    cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
                                    trained: Same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.
        device                  : torch.device
                                    torch device the input images will be supplied from.
        """
        self.use_bilinear_downup = use_bilinear_downup
        self.device = device

        filters = get_steerable_pyramid_filters(
            filter_size, n_orientations, filter_type)

        def make_pad(filter):
            filter_size = filter.size(-1)
            pad_amt = (filter_size-1) // 2
            return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))

        if not self.use_bilinear_downup:
            self.filt_l = filters["l"].to(device)
            self.pad_l = make_pad(self.filt_l)
        self.filt_l0 = filters["l0"].to(device)
        self.pad_l0 = make_pad(self.filt_l0)
        self.filt_h0 = filters["h0"].to(device)
        self.pad_h0 = make_pad(self.filt_h0)
        for b in range(len(filters["b"])):
            filters["b"][b] = filters["b"][b].to(device)
        self.band_filters = filters["b"]
        self.pad_b = make_pad(self.band_filters[0])

        if n_channels != 1:
            def add_channels_to_filter(filter):
                padded = torch.zeros(n_channels, n_channels, filter.size()[
                                     2], filter.size()[3]).to(device)
                for channel in range(n_channels):
                    padded[channel, channel, :, :] = filter
                return padded
            self.filt_h0 = add_channels_to_filter(self.filt_h0)
            for b in range(len(self.band_filters)):
                self.band_filters[b] = add_channels_to_filter(
                    self.band_filters[b])
            self.filt_l0 = add_channels_to_filter(self.filt_l0)
            if not self.use_bilinear_downup:
                self.filt_l = add_channels_to_filter(self.filt_l)

    def construct_pyramid(self, image, n_levels, multiple_highpass=False):
        """
        Constructs and returns a steerable pyramid for the provided image.

        Parameters
        ----------

        image               : torch.tensor
                                The input image, in NCHW format. The number of channels C should match n_channels
                                when the pyramid maker was created.
        n_levels            : int
                                Number of levels in the constructed steerable pyramid.
        multiple_highpass   : bool
                                If true, computes a highpass for each level of the pyramid.
                                These extra levels are redundant (not used for reconstruction).

        Returns
        -------

        pyramid             : list of dicts of torch.tensor
                                The computed steerable pyramid.
                                Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.
                                Each level is stored as a dict, with the following keys:
                                "h" Highpass residual
                                "l" Lowpass residual
                                "b" Oriented bands (a list of torch.tensor)
        """
        pyramid = []

        # Make level 0, containing highpass, lowpass and the bands
        level0 = {}
        level0['h'] = torch.nn.functional.conv2d(
            self.pad_h0(image), self.filt_h0)
        lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
        level0['l'] = lowpass.clone()
        bands = []
        for filt_b in self.band_filters:
            bands.append(torch.nn.functional.conv2d(
                self.pad_b(lowpass), filt_b))
        level0['b'] = bands
        pyramid.append(level0)

        # Make intermediate levels
        for l in range(n_levels-2):
            level = {}
            if self.use_bilinear_downup:
                lowpass = torch.nn.functional.interpolate(
                    lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
            else:
                lowpass = torch.nn.functional.conv2d(
                    self.pad_l(lowpass), self.filt_l)
                lowpass = lowpass[:, :, ::2, ::2]
            level['l'] = lowpass.clone()
            bands = []
            for filt_b in self.band_filters:
                bands.append(torch.nn.functional.conv2d(
                    self.pad_b(lowpass), filt_b))
            level['b'] = bands
            if multiple_highpass:
                level['h'] = torch.nn.functional.conv2d(
                    self.pad_h0(lowpass), self.filt_h0)
            pyramid.append(level)

        # Make final level (lowpass residual)
        level = {}
        if self.use_bilinear_downup:
            lowpass = torch.nn.functional.interpolate(
                lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
        else:
            lowpass = torch.nn.functional.conv2d(
                self.pad_l(lowpass), self.filt_l)
            lowpass = lowpass[:, :, ::2, ::2]
        level['l'] = lowpass
        pyramid.append(level)

        return pyramid

    def reconstruct_from_pyramid(self, pyramid):
        """
        Reconstructs an input image from a steerable pyramid.

        Parameters
        ----------

        pyramid : list of dicts of torch.tensor
                    The steerable pyramid.
                    Should be in the same format as output by construct_pyramid().
                    The number of channels should match n_channels when the pyramid maker was created.

        Returns
        -------

        image   : torch.tensor
                    The reconstructed image, in NCHW format.         
        """
        def upsample(image, size):
            if self.use_bilinear_downup:
                return torch.nn.functional.interpolate(image, size=size, mode="bilinear", align_corners=False, recompute_scale_factor=False)
            else:
                zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[
                                    2]*2, image.size()[3]*2)).to(self.device)
                zeros[:, :, ::2, ::2] = image
                zeros = torch.nn.functional.conv2d(
                    self.pad_l(zeros), self.filt_l)
                return zeros

        image = pyramid[-1]['l']
        for level in reversed(pyramid[:-1]):
            image = upsample(image, level['b'][0].size()[2:])
            for b in range(len(level['b'])):
                b_filtered = torch.nn.functional.conv2d(
                    self.pad_b(level['b'][b]), -self.band_filters[b])
                image += b_filtered

        image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
        image += torch.nn.functional.conv2d(
            self.pad_h0(pyramid[0]['h']), self.filt_h0)

        return image
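
Example (an illustrative sketch): constructing a pyramid and reconstructing the image from it. The import path is assumed to match this page's module; as noted above, the reconstruction is approximate.

import torch
from odak.learn.perception import SpatialSteerablePyramid

image = torch.rand(1, 3, 256, 256)
pyramid_maker = SpatialSteerablePyramid(use_bilinear_downup = True, n_channels = 3)
pyramid = pyramid_maker.construct_pyramid(image, n_levels = 5)
highpass_residual = pyramid[0]['h']    # Finest-level highpass residual.
oriented_bands = pyramid[0]['b']       # List of oriented band tensors.
lowpass_residual = pyramid[-1]['l']    # Coarsest-level lowpass residual.
reconstruction = pyramid_maker.reconstruct_from_pyramid(pyramid)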

__init__(use_bilinear_downup=True, n_channels=1, filter_size=9, n_orientations=6, filter_type='full', device=torch.device('cpu'))

Parameters:

  • use_bilinear_downup –
                        This uses bilinear filtering when upsampling/downsampling, rather than the original approach
                        of applying a large lowpass kernel and sampling even rows/columns
    
  • n_channels –
                        Number of channels in the input images (e.g. 3 for RGB input)
    
  • filter_size –
                        Desired size of filters (e.g. 3 will use 3x3 filters).
    
  • n_orientations –
                        Number of oriented bands in each level of the pyramid.
    
  • filter_type –
                        This can be used to select smaller filters than the original ones if desired.
                        full: Original filter sizes
                        cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
                        trained: Same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.
    
  • device –
                        torch device the input images will be supplied from.
    
Source code in odak/learn/perception/spatial_steerable_pyramid.py
def __init__(self, use_bilinear_downup=True, n_channels=1,
             filter_size=9, n_orientations=6, filter_type="full",
             device=torch.device('cpu')):
    """
    Parameters
    ----------

    use_bilinear_downup     : bool
                                This uses bilinear filtering when upsampling/downsampling, rather than the original approach
                                of applying a large lowpass kernel and sampling even rows/columns
    n_channels              : int
                                Number of channels in the input images (e.g. 3 for RGB input)
    filter_size             : int
                                Desired size of filters (e.g. 3 will use 3x3 filters).
    n_orientations          : int
                                Number of oriented bands in each level of the pyramid.
    filter_type             : str
                                This can be used to select smaller filters than the original ones if desired.
                                full: Original filter sizes
                                cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
                                trained: Same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.
    device                  : torch.device
                                torch device the input images will be supplied from.
    """
    self.use_bilinear_downup = use_bilinear_downup
    self.device = device

    filters = get_steerable_pyramid_filters(
        filter_size, n_orientations, filter_type)

    def make_pad(filter):
        filter_size = filter.size(-1)
        pad_amt = (filter_size-1) // 2
        return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))

    if not self.use_bilinear_downup:
        self.filt_l = filters["l"].to(device)
        self.pad_l = make_pad(self.filt_l)
    self.filt_l0 = filters["l0"].to(device)
    self.pad_l0 = make_pad(self.filt_l0)
    self.filt_h0 = filters["h0"].to(device)
    self.pad_h0 = make_pad(self.filt_h0)
    for b in range(len(filters["b"])):
        filters["b"][b] = filters["b"][b].to(device)
    self.band_filters = filters["b"]
    self.pad_b = make_pad(self.band_filters[0])

    if n_channels != 1:
        def add_channels_to_filter(filter):
            padded = torch.zeros(n_channels, n_channels, filter.size()[
                                 2], filter.size()[3]).to(device)
            for channel in range(n_channels):
                padded[channel, channel, :, :] = filter
            return padded
        self.filt_h0 = add_channels_to_filter(self.filt_h0)
        for b in range(len(self.band_filters)):
            self.band_filters[b] = add_channels_to_filter(
                self.band_filters[b])
        self.filt_l0 = add_channels_to_filter(self.filt_l0)
        if not self.use_bilinear_downup:
            self.filt_l = add_channels_to_filter(self.filt_l)
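
For multi-channel input, the constructor can be configured as in the sketch below; the class name and import path are the same assumptions as above, and the chosen values simply mirror the documented defaults.

import torch
from odak.learn.perception.spatial_steerable_pyramid import SpatialSteerablePyramid  # assumed export name

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Pyramid maker for RGB images; filter_type could be switched to 'cropped' or
# 'trained' for the smaller kernels described in the parameter list above.
rgb_pyramid_maker = SpatialSteerablePyramid(use_bilinear_downup=True,
                                            n_channels=3,
                                            filter_size=9,
                                            n_orientations=6,
                                            filter_type='full',
                                            device=device)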

construct_pyramid(image, n_levels, multiple_highpass=False)

Constructs and returns a steerable pyramid for the provided image.

Parameters:

  • image –
                    The input image, in NCHW format. The number of channels C should match n_channels
                    when the pyramid maker was created.
    
  • n_levels –
                    Number of levels in the constructed steerable pyramid.
    
  • multiple_highpass –
                    If true, computes a highpass for each level of the pyramid.
                    These extra levels are redundant (not used for reconstruction).
    

Returns:

  • pyramid ( list of dicts of torch.tensor ) –

    The computed steerable pyramid. Each level is an entry in a list, ordered from largest to smallest. Each level is stored as a dict with the following keys: "h" (highpass residual), "l" (lowpass residual) and "b" (oriented bands, a list of torch.tensor).

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def construct_pyramid(self, image, n_levels, multiple_highpass=False):
    """
    Constructs and returns a steerable pyramid for the provided image.

    Parameters
    ----------

    image               : torch.tensor
                            The input image, in NCHW format. The number of channels C should match n_channels
                            when the pyramid maker was created.
    n_levels            : int
                            Number of levels in the constructed steerable pyramid.
    multiple_highpass   : bool
                            If true, computes a highpass for each level of the pyramid.
                            These extra levels are redundant (not used for reconstruction).

    Returns
    -------

    pyramid             : list of dicts of torch.tensor
                            The computed steerable pyramid.
                            Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.
                            Each level is stored as a dict, with the following keys:
                            "h" Highpass residual
                            "l" Lowpass residual
                            "b" Oriented bands (a list of torch.tensor)
    """
    pyramid = []

    # Make level 0, containing highpass, lowpass and the bands
    level0 = {}
    level0['h'] = torch.nn.functional.conv2d(
        self.pad_h0(image), self.filt_h0)
    lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
    level0['l'] = lowpass.clone()
    bands = []
    for filt_b in self.band_filters:
        bands.append(torch.nn.functional.conv2d(
            self.pad_b(lowpass), filt_b))
    level0['b'] = bands
    pyramid.append(level0)

    # Make intermediate levels
    for l in range(n_levels-2):
        level = {}
        if self.use_bilinear_downup:
            lowpass = torch.nn.functional.interpolate(
                lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
        else:
            lowpass = torch.nn.functional.conv2d(
                self.pad_l(lowpass), self.filt_l)
            lowpass = lowpass[:, :, ::2, ::2]
        level['l'] = lowpass.clone()
        bands = []
        for filt_b in self.band_filters:
            bands.append(torch.nn.functional.conv2d(
                self.pad_b(lowpass), filt_b))
        level['b'] = bands
        if multiple_highpass:
            level['h'] = torch.nn.functional.conv2d(
                self.pad_h0(lowpass), self.filt_h0)
        pyramid.append(level)

    # Make final level (lowpass residual)
    level = {}
    if self.use_bilinear_downup:
        lowpass = torch.nn.functional.interpolate(
            lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
    else:
        lowpass = torch.nn.functional.conv2d(
            self.pad_l(lowpass), self.filt_l)
        lowpass = lowpass[:, :, ::2, ::2]
    level['l'] = lowpass
    pyramid.append(level)

    return pyramid
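
To make the returned structure concrete, the sketch below walks the list of levels and prints the keys and band shapes; the import is the same assumption as above.

import torch
from odak.learn.perception.spatial_steerable_pyramid import SpatialSteerablePyramid  # assumed export name

pyramid_maker = SpatialSteerablePyramid(n_channels=1)
pyramid = pyramid_maker.construct_pyramid(torch.rand(1, 1, 128, 128), n_levels=4)

for i, level in enumerate(pyramid):
    print('level', i, 'keys:', sorted(level.keys()))      # ['b', 'h', 'l'] for level 0, ['l'] for the last level
    if 'b' in level:
        print('  bands:', len(level['b']),
              'band shape:', tuple(level['b'][0].shape))  # NCHW, halving in resolution at each level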

reconstruct_from_pyramid(pyramid)

Reconstructs an input image from a steerable pyramid.

Parameters:

  • pyramid (list of dicts of torch.tensor) –
        The steerable pyramid.
        Should be in the same format as output by construct_pyramid().
        The number of channels should match n_channels when the pyramid maker was created.
    

Returns:

  • image ( tensor ) –

    The reconstructed image, in NCHW format.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def reconstruct_from_pyramid(self, pyramid):
    """
    Reconstructs an input image from a steerable pyramid.

    Parameters
    ----------

    pyramid : list of dicts of torch.tensor
                The steerable pyramid.
                Should be in the same format as output by construct_pyramid().
                The number of channels should match n_channels when the pyramid maker was created.

    Returns
    -------

    image   : torch.tensor
                The reconstructed image, in NCHW format.         
    """
    def upsample(image, size):
        if self.use_bilinear_downup:
            return torch.nn.functional.interpolate(image, size=size, mode="bilinear", align_corners=False, recompute_scale_factor=False)
        else:
            zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[
                                2]*2, image.size()[3]*2)).to(self.device)
            zeros[:, :, ::2, ::2] = image
            zeros = torch.nn.functional.conv2d(
                self.pad_l(zeros), self.filt_l)
            return zeros

    image = pyramid[-1]['l']
    for level in reversed(pyramid[:-1]):
        image = upsample(image, level['b'][0].size()[2:])
        for b in range(len(level['b'])):
            b_filtered = torch.nn.functional.conv2d(
                self.pad_b(level['b'][b]), -self.band_filters[b])
            image += b_filtered

    image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
    image += torch.nn.functional.conv2d(
        self.pad_h0(pyramid[0]['h']), self.filt_h0)

    return image
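
A typical use of the pyramid is to edit the decomposition before reconstructing. The sketch below, under the same import assumption, attenuates the oriented bands of the finest level and rebuilds a slightly smoothed image.

import torch
from odak.learn.perception.spatial_steerable_pyramid import SpatialSteerablePyramid  # assumed export name

pyramid_maker = SpatialSteerablePyramid(n_channels=1)
image = torch.rand(1, 1, 256, 256)
pyramid = pyramid_maker.construct_pyramid(image, n_levels=5)

# Halve the finest oriented bands, then reconstruct.
pyramid[0]['b'] = [band * 0.5 for band in pyramid[0]['b']]
smoothed = pyramid_maker.reconstruct_from_pyramid(pyramid)
print(smoothed.shape)  # same NCHW shape as the input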

display_color_hvs

Source code in odak/learn/perception/color_conversion.py
class display_color_hvs():

    def __init__(
                 self,
                 resolution = [1920, 1080],
                 distance_from_screen = 800,
                 pixel_pitch = 0.311,
                 read_spectrum = 'tensor',
                 primaries_spectrum = torch.rand(3, 301),
                 device = torch.device('cpu')):
        '''
        Parameters
        ----------
        resolution                  : list
                                      Resolution of the display in pixels.
        distance_from_screen        : int
                                      Distance from the screen in mm.
        pixel_pitch                 : float
                                      Pixel pitch of the display in mm.
        read_spectrum               : str
                                      How the display spectrum is provided. Default is 'tensor', which reads the primary spectra from primaries_spectrum.
        primaries_spectrum          : torch.tensor
                                      Spectra of the display primaries, sampled over 400-700 nm (number of primaries x 301).
        device                      : torch.device
                                      Device to run the code on. Default is torch.device('cpu').

        '''
        self.device = device
        self.read_spectrum = read_spectrum
        self.primaries_spectrum = primaries_spectrum.to(self.device)
        self.resolution = resolution
        self.distance_from_screen = distance_from_screen
        self.pixel_pitch = pixel_pitch
        self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()
        self.lms_tensor = self.construct_matrix_lms(
                                                    self.l_normalized,
                                                    self.m_normalized,
                                                    self.s_normalized
                                                   )   
        self.primaries_tensor = self.construct_matrix_primaries(
                                                                self.l_normalized,
                                                                self.m_normalized,
                                                                self.s_normalized
                                                               )   
        return


    def __call__(self, input_image, ground_truth, gaze=None):
        """
        Evaluates an input image against a target ground truth image for a given viewer gaze, returning a metameric color loss.
        """
        lms_image_second = self.primaries_to_lms(input_image.to(self.device))
        lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))
        lms_image_third = self.second_to_third_stage(lms_image_second)
        lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)
        loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)
        return loss_metamer_color


    def initialize_cones_normalized(self):
        """
        Internal function to initialize normalized L, M, S cone sensitivities as normal distributions with given sigma and mu values.

        Returns
        -------
        l_cone_n                     : torch.tensor
                                       Normalised L cone distribution.
        m_cone_n                     : torch.tensor
                                       Normalised M cone distribution.
        s_cone_n                     : torch.tensor
                                       Normalised S cone distribution.
        """
        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
        dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))
        dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))
        dist_s = 1 / (17.0 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 447.5) ** 2 / (2 * 17.0 ** 2))

        l_cone_n = dist_l / dist_l.max()
        m_cone_n = dist_m / dist_m.max()
        s_cone_n = dist_s / dist_s.max()
        return l_cone_n, m_cone_n, s_cone_n


    def initialize_rgb_backlight_spectrum(self):
        """
        Internal function to initialize the backlight spectrum for the color primaries.

        Returns
        -------
        red_spectrum                 : torch.tensor
                                       Normalised backlight spectrum for red color primary.
        green_spectrum               : torch.tensor
                                       Normalised backlight spectrum for green color primary.
        blue_spectrum                : torch.tensor
                                       Normalised backlight spectrum for blue color primary.
        """
        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
        red_spectrum = 1 / (14.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 650) ** 2 / (2 * 14.5 ** 2))
        green_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 550) ** 2 / (2 * 12.0 ** 2))
        blue_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 450) ** 2 / (2 * 12.0 ** 2))

        red_spectrum = red_spectrum / red_spectrum.max()
        green_spectrum = green_spectrum / green_spectrum.max()
        blue_spectrum = blue_spectrum / blue_spectrum.max()

        return red_spectrum, green_spectrum, blue_spectrum


    def initialize_random_spectrum_normalized(self, dataset):
        """
        Initialize a normalized light spectrum by fitting a sum of three Gaussian curves to the provided dataset using L-BFGS.

        Parameters
        ----------
        dataset                                : torch.tensor 
                                                 spectrum value against wavelength 
        """
        dataset = torch.swapaxes(dataset, 0, 1)
        x_spectrum = torch.linspace(400, 700, steps = 301) - 550
        y_spectrum = torch.from_numpy(np_cpu.interp(x_spectrum, dataset[0].numpy(), dataset[1].numpy()))
        max_spectrum = torch.max(y_spectrum)
        y_spectrum /= max_spectrum

        def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \
            torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))

        def function(x, weights): 
            return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])

        weights = torch.tensor([1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad = True)
        optimizer = torch.optim.LBFGS([weights], max_iter = 1000, lr = 0.1, line_search_fn = None)

        def closure():
            optimizer.zero_grad()
            output = function(x_spectrum, weights)
            loss = F.mse_loss(output, y_spectrum)
            loss.backward()
            return loss
        optimizer.step(closure)
        spectrum = function(x_spectrum, weights)
        return spectrum.detach().to(self.device)


    def display_spectrum_response(wavelength, function):
        """
        Internal function to provide the light spectrum response at a particular wavelength.

        Parameters
        ----------
        wavelength                          : float
                                              Wavelength in nm [400...700]
        function                            : torch.tensor
                                              Display light spectrum distribution function

        Returns
        -------
        light_response                       : float
                                               Display light spectrum response value
        """
        wavelength = int(round(wavelength, 0))
        if wavelength >= 400 and wavelength <= 700:
            return function[wavelength - 400].item()
        elif wavelength < 400:
            return function[0].item()
        else:
            return function[300].item()


    def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
        """
        Internal function to calculate the cone response to a particular light spectrum.

        Parameters
        ----------
        cone_spectrum                         : torch.tensor
                                                Cone sensitivity spectrum, sampled over 400-700 nm [301]
        light_spectrum                        : torch.tensor
                                                Light spectrum sampled on the same wavelength grid [301]


        Returns
        -------
        response_to_spectrum                  : float
                                                Response of cone to light spectrum [1x1] 
        """
        response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)
        response_to_spectrum = torch.sum(response_to_spectrum)
        return response_to_spectrum.item()


    def construct_matrix_lms(self, l_response, m_response, s_response):
        '''
        Internal function to construct the matrix of LMS cone responses to the display primaries.

        Parameters
        ----------
        l_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalized response vs wavelength)
        m_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalized response vs wavelength)
        s_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalized response vs wavelength)



        Returns
        -------
        lms_image_tensor                      : torch.tensor
                                                LMS response matrix of the display primaries (number of primaries x 3; 3x3 for an RGB display)

        '''
        if self.read_spectrum == 'tensor':
            logging.warning('Tensor primary spectrum is used')
            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
        else:
            logging.warning("No Spectrum data is provided")

        self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)
        for i in range(self.primaries_spectrum.shape[0]):
            self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])
            self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])
            self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) 
        return self.lms_tensor    


    def construct_matrix_primaries(self, l_response, m_response, s_response):
        '''
        Internal function to construct the matrix of cone responses per display primary (the transpose of the LMS matrix).

        Parameters
        ----------
        l_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalized response vs wavelength)
        m_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalized response vs wavelength)
        s_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalized response vs wavelength)



        Returns
        -------
        lms_image_tensor                      : torch.tensor
                                                Primaries response matrix (3 x number of primaries; 3x3 for an RGB display)

        '''
        if self.read_spectrum == 'tensor':
            logging.warning('Tensor primary spectrum is used')
            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
        else:
            logging.warning("No Spectrum data is provided")

        self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)
        for i in range(self.primaries_spectrum.shape[0]):
            self.primaries_tensor[0, i] = self.cone_response_to_spectrum(
                                                                         l_response,
                                                                         self.primaries_spectrum[i]
                                                                        )
            self.primaries_tensor[1, i] = self.cone_response_to_spectrum(
                                                                         m_response,
                                                                         self.primaries_spectrum[i]
                                                                        )
            self.primaries_tensor[2, i] = self.cone_response_to_spectrum(
                                                                         s_response,
                                                                         self.primaries_spectrum[i]
                                                                        ) 
        return self.primaries_tensor    


    def primaries_to_lms(self, primaries):
        """
        Internal function to convert primaries space to LMS space 

        Parameters
        ----------
        primaries                              : torch.tensor
                                                 Primaries data to be transformed to LMS space [B x P x H x W]


        Returns
        -------
        lms_color                              : torch.tensor
                                                 LMS data transformed from Primaries space [BxPxHxW]
        """                
        primaries_flatten = primaries.reshape(primaries.shape[0], primaries.shape[1], 1, -1)
        lms = self.lms_tensor.unsqueeze(0).unsqueeze(-1)
        lms_color = torch.sum(primaries_flatten * lms, axis = 1).reshape(primaries.shape)
        return lms_color


    def lms_to_primaries(self, lms_color_tensor):
        """
        Internal function to convert LMS image to primaries space

        Parameters
        ----------
        lms_color_tensor                        : torch.tensor
                                                  LMS data to be transformed to primaries space [Bx3xHxW]


        Returns
        -------
        primaries                              : torch.tensor
                                                 Primaries data transformed from LMS space [B x P x H x W]
        """
        lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)
        lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)
        unflatten = torch.nn.Unflatten(0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))
        converted_unflatten = torch.matmul(lms_color_flatten.double(), self.lms_tensor.pinverse().double())
        primaries = unflatten(converted_unflatten)     
        primaries = primaries.permute(0, 3, 1, 2)   
        return primaries


    def second_to_third_stage(self, lms_image):
        '''
        This function turns second-stage [L, M, S] values into third-stage [(M+S)-L, (L+S)-M, L+M+S] values.
        See Table 1 of Schmidt et al., "Neurobiological hypothesis of color appearance and hue perception," Optics Express 2014.

        Parameters
        ----------
        lms_image                             : torch.tensor
                                                 Image data at LMS space (second stage)

        Returns
        -------
        third_stage                            : torch.tensor
                                                 Image data in the third-stage opponent space [(M+S)-L, (L+S)-M, L+M+S]

        '''
        third_stage = torch.zeros_like(lms_image)
        third_stage[:, 0] = (lms_image[:, 1] + lms_image[:, 2]) - lms_image[:, 0]
        third_stage[:, 1] = (lms_image[:, 0] + lms_image[:, 2]) - lms_image[:, 1]
        third_stage[:, 2] = lms_image[:, 0] + lms_image[:, 1]  + lms_image[:, 2]
        return third_stage
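
Putting the class together, the sketch below uses display_color_hvs as a perceptual color loss between two RGB images. The import path follows the source file named above; the random primaries_spectrum stands in for measured display spectra (number of primaries x 301 samples over 400-700 nm).

import torch
from odak.learn.perception.color_conversion import display_color_hvs  # module path shown above

# Random values stand in for measured primary spectra here (3 x 301, 400-700 nm).
primaries_spectrum = torch.rand(3, 301)

loss_function = display_color_hvs(read_spectrum='tensor',
                                  primaries_spectrum=primaries_spectrum,
                                  device=torch.device('cpu'))

estimate = torch.rand(1, 3, 64, 64)      # NCHW RGB estimate
target = torch.rand(1, 3, 64, 64)        # NCHW RGB ground truth
loss = loss_function(estimate, target)   # scalar metameric color loss
print(float(loss))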

__call__(input_image, ground_truth, gaze=None)

Evaluates an input image against a target ground truth image for a given viewer gaze, returning a metameric color loss.

Source code in odak/learn/perception/color_conversion.py
def __call__(self, input_image, ground_truth, gaze=None):
    """
    Evaluates an input image against a target ground truth image for a given viewer gaze, returning a metameric color loss.
    """
    lms_image_second = self.primaries_to_lms(input_image.to(self.device))
    lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))
    lms_image_third = self.second_to_third_stage(lms_image_second)
    lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)
    loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)
    return loss_metamer_color
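
Because the loss is differentiable with respect to the input image, it can drive gradient-based optimisation directly. A minimal sketch under the same import assumption:

import torch
from odak.learn.perception.color_conversion import display_color_hvs  # module path shown above

loss_function = display_color_hvs(primaries_spectrum=torch.rand(3, 301))
target = torch.rand(1, 3, 64, 64)
estimate = torch.rand(1, 3, 64, 64, requires_grad=True)
optimizer = torch.optim.Adam([estimate], lr=1e-2)

for step in range(200):
    optimizer.zero_grad()
    loss = loss_function(estimate, target)
    loss.backward()
    optimizer.step()
print(float(loss_function(estimate, target)))  # loss after optimisation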

__init__(resolution=[1920, 1080], distance_from_screen=800, pixel_pitch=0.311, read_spectrum='tensor', primaries_spectrum=torch.rand(3, 301), device=torch.device('cpu'))

Parameters:

  • resolution –
                          Resolution of the display in pixels.
    
  • distance_from_screen –
                          Distance from the screen in mm.
    
  • pixel_pitch –
                          Pixel pitch of the display in mm.
    
  • read_spectrum –
                          How the display spectrum is provided. Default is 'tensor', which reads the primary spectra from primaries_spectrum.
    
  • primaries_spectrum –
                          Spectra of the display primaries, sampled over 400-700 nm (number of primaries x 301).
    
  • device –
                          Device to run the code on. Default is torch.device('cpu').
    
Source code in odak/learn/perception/color_conversion.py
def __init__(
             self,
             resolution = [1920, 1080],
             distance_from_screen = 800,
             pixel_pitch = 0.311,
             read_spectrum = 'tensor',
             primaries_spectrum = torch.rand(3, 301),
             device = torch.device('cpu')):
    '''
    Parameters
    ----------
    resolution                  : list
                                  Resolution of the display in pixels.
    distance_from_screen        : int
                                  Distance from the screen in mm.
    pixel_pitch                 : float
                                  Pixel pitch of the display in mm.
    read_spectrum               : str
                                  How the display spectrum is provided. Default is 'tensor', which reads the primary spectra from primaries_spectrum.
    primaries_spectrum          : torch.tensor
                                  Spectra of the display primaries, sampled over 400-700 nm (number of primaries x 301).
    device                      : torch.device
                                  Device to run the code on. Default is torch.device('cpu').

    '''
    self.device = device
    self.read_spectrum = read_spectrum
    self.primaries_spectrum = primaries_spectrum.to(self.device)
    self.resolution = resolution
    self.distance_from_screen = distance_from_screen
    self.pixel_pitch = pixel_pitch
    self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()
    self.lms_tensor = self.construct_matrix_lms(
                                                self.l_normalized,
                                                self.m_normalized,
                                                self.s_normalized
                                               )   
    self.primaries_tensor = self.construct_matrix_primaries(
                                                            self.l_normalized,
                                                            self.m_normalized,
                                                            self.s_normalized
                                                           )   
    return
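
Once constructed, the instance also exposes the color conversion helpers directly. An RGB to LMS and back round-trip sketch under the same import assumption:

import torch
from odak.learn.perception.color_conversion import display_color_hvs  # module path shown above

display = display_color_hvs(primaries_spectrum=torch.rand(3, 301))
rgb = torch.rand(1, 3, 32, 32)                      # Bx3xHxW image in the primaries space
lms = display.primaries_to_lms(rgb)                 # second-stage cone responses
rgb_back = display.lms_to_primaries(lms)            # back to the primaries space
print(torch.mean((rgb_back - rgb.double()) ** 2))   # near zero when lms_tensor is well conditioned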

cone_response_to_spectrum(cone_spectrum, light_spectrum)

Internal function to calculate the cone response to a particular light spectrum.

Parameters:

  • cone_spectrum –
                                    Cone sensitivity spectrum, sampled over 400-700 nm [301]
    
  • light_spectrum –
                                    Light spectrum sampled on the same wavelength grid [301]
    

Returns:

  • response_to_spectrum ( float ) –

    Response of cone to light spectrum [1x1]

Source code in odak/learn/perception/color_conversion.py
def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
    """
    Internal function to calculate the cone response to a particular light spectrum.

    Parameters
    ----------
    cone_spectrum                         : torch.tensor
                                            Cone sensitivity spectrum, sampled over 400-700 nm [301]
    light_spectrum                        : torch.tensor
                                            Light spectrum sampled on the same wavelength grid [301]


    Returns
    -------
    response_to_spectrum                  : float
                                            Response of cone to light spectrum [1x1] 
    """
    response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)
    response_to_spectrum = torch.sum(response_to_spectrum)
    return response_to_spectrum.item()
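
The returned value is simply the inner product of the two sampled spectra. A standalone sketch with hypothetical spectra on the 301-sample, 400-700 nm grid used elsewhere in this class:

import torch

cone_spectrum = torch.rand(301)    # hypothetical cone sensitivity, 400-700 nm at 1 nm steps
light_spectrum = torch.rand(301)   # hypothetical light spectrum on the same grid

# Equivalent to cone_response_to_spectrum(cone_spectrum, light_spectrum):
response = torch.sum(cone_spectrum * light_spectrum).item()
print(response)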

construct_matrix_lms(l_response, m_response, s_response)

Internal function to construct the matrix of LMS cone responses to the display primaries.

Parameters:

  • l_response –
                                     Cone response spectrum tensor (normalized response vs wavelength)
    
  • m_response –
                                     Cone response spectrum tensor (normalized response vs wavelength)
    
  • s_response –
                                     Cone response spectrum tensor (normalized response vs wavelength)
    

Returns:

  • lms_image_tensor ( tensor ) –

    LMS response matrix of the display primaries (number of primaries x 3; 3x3 for an RGB display)

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_lms(self, l_response, m_response, s_response):
    '''
    Internal function to construct the matrix of LMS cone responses to the display primaries.

    Parameters
    ----------
    l_response                             : torch.tensor
                                             Cone response spectrum tensor (normalized response vs wavelength)
    m_response                             : torch.tensor
                                             Cone response spectrum tensor (normalized response vs wavelength)
    s_response                             : torch.tensor
                                             Cone response spectrum tensor (normalized response vs wavelength)



    Returns
    -------
    lms_image_tensor                      : torch.tensor
                                            LMS response matrix of the display primaries (number of primaries x 3; 3x3 for an RGB display)

    '''
    if self.read_spectrum == 'tensor':
        logging.warning('Tensor primary spectrum is used')
        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
    else:
        logging.warning("No Spectrum data is provided")

    self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)
    for i in range(self.primaries_spectrum.shape[0]):
        self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])
        self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])
        self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) 
    return self.lms_tensor    
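
To make the shape and role of the result concrete, the sketch below (same import assumption as earlier) prints the matrix built for a three-primary display.

import torch
from odak.learn.perception.color_conversion import display_color_hvs  # module path shown above

display = display_color_hvs(primaries_spectrum=torch.rand(3, 301))
# Row i holds the L, M and S responses to primary i; primaries_to_lms() multiplies by this matrix.
print(display.lms_tensor.shape)  # torch.Size([3, 3]) for an RGB display
print(display.lms_tensor)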

construct_matrix_primaries(l_response, m_response, s_response)

Internal function to construct the matrix of cone responses per display primary (the transpose of the LMS matrix).

Parameters:

  • l_response –
                                     Cone response spectrum tensor (normalized response vs wavelength)
    
  • m_response –
                                     Cone response spectrum tensor (normalized response vs wavelength)
    
  • s_response –
                                     Cone response spectrum tensor (normalized response vs wavelength)
    

Returns:

  • lms_image_tensor ( tensor ) –

    Primaries response matrix (3 x number of primaries; 3x3 for an RGB display)

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_primaries(self, l_response, m_response, s_response):
    '''
    Internal function to construct the matrix of cone responses per display primary (the transpose of the LMS matrix).

    Parameters
    ----------
    l_response                             : torch.tensor
                                             Cone response spectrum tensor (normalized response vs wavelength)
    m_response                             : torch.tensor
                                             Cone response spectrum tensor (normalized response vs wavelength)
    s_response                             : torch.tensor
                                             Cone response spectrum tensor (normalized response vs wavelength)



    Returns
    -------
    lms_image_tensor                      : torch.tensor
                                            Primaries response matrix (3 x number of primaries; 3x3 for an RGB display)

    '''
    if self.read_spectrum == 'tensor':
        logging.warning('Tensor primary spectrum is used')
        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
    else:
        logging.warning("No Spectrum data is provided")

    self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)
    for i in range(self.primaries_spectrum.shape[0]):
        self.primaries_tensor[0, i] = self.cone_response_to_spectrum(
                                                                     l_response,
                                                                     self.primaries_spectrum[i]
                                                                    )
        self.primaries_tensor[1, i] = self.cone_response_to_spectrum(
                                                                     m_response,
                                                                     self.primaries_spectrum[i]
                                                                    )
        self.primaries_tensor[2, i] = self.cone_response_to_spectrum(
                                                                     s_response,
                                                                     self.primaries_spectrum[i]
                                                                    ) 
    return self.primaries_tensor    

display_spectrum_response(wavelength, function)

Internal function to provide the light spectrum response at a particular wavelength.

Parameters:

  • wavelength –
                                  Wavelength in nm [400...700]
    
  • function –
                                  Display light spectrum distribution function
    

Returns:

  • light_response ( float ) –

    Display light spectrum response value

Source code in odak/learn/perception/color_conversion.py
def display_spectrum_response(wavelength, function):
    """
    Internal function to provide the light spectrum response at a particular wavelength.

    Parameters
    ----------
    wavelength                          : float
                                          Wavelength in nm [400...700]
    function                            : torch.tensor
                                          Display light spectrum distribution function

    Returns
    -------
    light_response                       : float
                                           Display light spectrum response value
    """
    wavelength = int(round(wavelength, 0))
    if wavelength >= 400 and wavelength <= 700:
        return function[wavelength - 400].item()
    elif wavelength < 400:
        return function[0].item()
    else:
        return function[300].item()
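
The lookup rounds the wavelength, clamps it to the 400-700 nm range and indexes the 301-sample spectrum. The standalone sketch below reproduces that logic with a hypothetical spectrum:

import torch

spectrum = torch.rand(301)                       # hypothetical spectrum, 400-700 nm at 1 nm steps

def spectrum_response(wavelength, spectrum):
    index = min(max(int(round(wavelength)), 400), 700) - 400   # clamp to the sampled range
    return spectrum[index].item()

print(spectrum_response(550.3, spectrum))        # response at ~550 nm
print(spectrum_response(380.0, spectrum))        # below the range, clamped to 400 nm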

initialize_cones_normalized()

Internal function to initialize normalized L, M, S cone sensitivities as normal distributions with given sigma and mu values.

Returns:

  • l_cone_n ( tensor ) –

    Normalised L cone distribution.

  • m_cone_n ( tensor ) –

    Normalised M cone distribution.

  • s_cone_n ( tensor ) –

    Normalised S cone distribution.

Source code in odak/learn/perception/color_conversion.py
def initialize_cones_normalized(self):
    """
    Internal function to initialize normalized L, M, S cone sensitivities as normal distributions with given sigma and mu values.

    Returns
    -------
    l_cone_n                     : torch.tensor
                                   Normalised L cone distribution.
    m_cone_n                     : torch.tensor
                                   Normalised M cone distribution.
    s_cone_n                     : torch.tensor
                                   Normalised S cone distribution.
    """
    wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
    dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))
    dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))