
odak.learn.perception

Defines a number of different perceptual loss functions, which can be used to optimise images where gaze location is known.

BlurLoss

BlurLoss implements two different blur losses. When blur_source is set to False, it implements blur_match, which tries to match the source input image to a blurred version of the target image.

When blur_source is set to True, it implements blur_lowpass, which matches a blurred version of the input image to the blurred target image, so that only the low frequencies of the source are matched to the low frequencies of the target.

The interface is similar to other PyTorch loss functions, but note that the gaze location must be provided in addition to the source and target images.
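A minimal usage sketch follows; the tensor shapes, gaze values, and the optimisation context are illustrative assumptions rather than an example taken from the library.

import torch
from odak.learn.perception import BlurLoss

# A minimal sketch: random tensors stand in for real images.
loss_fn = BlurLoss(alpha = 0.2, real_image_width = 0.2, real_viewing_distance = 0.7)
image = torch.rand(1, 3, 256, 256, requires_grad = True)  # source image, NCHW
target = torch.rand(1, 3, 256, 256)                       # ground truth, NCHW
gaze = [0.4, 0.6]                                          # normalised image coordinates, top-left origin
loss = loss_fn(image, target, gaze = gaze)
loss.backward()                                            # differentiable, so usable in an optimisation loop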

Source code in odak/learn/perception/blur_loss.py
class BlurLoss():
    """ 

    `BlurLoss` implements two different blur losses. When `blur_source` is set to `False`, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.

    When `blur_source` is set to `True`, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.

    The interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
    """


    def __init__(self, device=torch.device("cpu"),
                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
        """
        Parameters
        ----------

        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        blur_source             : bool
                                    If true, blurs the source image as well as the target before computing the loss.
        equi                    : bool
                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.
                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi]
        """
        self.target = None
        self.device = device
        self.alpha = alpha
        self.real_image_width = real_image_width
        self.real_viewing_distance = real_viewing_distance
        self.mode = mode
        self.blur = None
        self.loss_func = torch.nn.MSELoss()
        self.blur_source = blur_source
        self.equi = equi

    def blur_image(self, image, gaze):
        if self.blur is None:
            self.blur = RadiallyVaryingBlur()
        return self.blur.blur(image, self.alpha, self.real_image_width, self.real_viewing_distance, gaze, self.mode, self.equi)

    def __call__(self, image, target, gaze=[0.5, 0.5]):
        """ 
        Calculates the Blur Loss.

        Parameters
        ----------
        image               : torch.tensor
                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target              : torch.tensor
                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        gaze                : list
                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("BlurLoss", image, target)
        blurred_target = self.blur_image(target, gaze)
        if self.blur_source:
            blurred_image = self.blur_image(image, gaze)
            return self.loss_func(blurred_image, blurred_target)
        else:
            return self.loss_func(image, blurred_target)

    def to(self, device):
        self.device = device
        return self

__call__(image, target, gaze=[0.5, 0.5])

Calculates the Blur Loss.

Parameters:

  • image
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • gaze
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/blur_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):
    """ 
    Calculates the Blur Loss.

    Parameters
    ----------
    image               : torch.tensor
                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target              : torch.tensor
                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    gaze                : list
                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("BlurLoss", image, target)
    blurred_target = self.blur_image(target, gaze)
    if self.blur_source:
        blurred_image = self.blur_image(image, gaze)
        return self.loss_func(blurred_image, blurred_target)
    else:
        return self.loss_func(image, blurred_target)

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', blur_source=False, equi=False)

Parameters:

  • alpha
                        Parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
    
  • real_viewing_distance
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
    
  • mode
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • blur_source
                        If true, blurs the source image as well as the target before computing the loss.
    
  • equi
                        If true, runs the loss in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                        The gaze argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi/2]. See the equirectangular sketch after the source listing below.
    
Source code in odak/learn/perception/blur_loss.py
def __init__(self, device=torch.device("cpu"),
             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
    """
    Parameters
    ----------

    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    blur_source             : bool
                                If true, blurs the source image as well as the target before computing the loss.
    equi                    : bool
                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing distance are ignored.
                                The gaze argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi]
    """
    self.target = None
    self.device = device
    self.alpha = alpha
    self.real_image_width = real_image_width
    self.real_viewing_distance = real_viewing_distance
    self.mode = mode
    self.blur = None
    self.loss_func = torch.nn.MSELoss()
    self.blur_source = blur_source
    self.equi = equi
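For equirectangular (360) content, a hedged sketch of the same loss in equi mode follows; the panorama resolution and gaze angles are illustrative assumptions.

import math
import torch
from odak.learn.perception import BlurLoss

# Equirectangular mode sketch: gaze is given as angles rather than normalised
# image coordinates (azimuth in [-pi, pi], elevation in [-pi/2, pi/2]).
loss_fn = BlurLoss(equi = True)
panorama = torch.rand(1, 3, 512, 1024, requires_grad = True)  # 360 image, NCHW
target = torch.rand(1, 3, 512, 1024)
loss = loss_fn(panorama, target, gaze = [0.25 * math.pi, 0.0])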

display_color_hvs

Source code in odak/learn/perception/color_conversion.py
class display_color_hvs():

    def __init__(self, resolution = [1920, 1080],
                 distance_from_screen = 800,
                 pixel_pitch = 0.311,
                 read_spectrum = 'tensor',
                 primaries_spectrum = torch.rand(3, 301),
                 device = torch.device('cpu')):
        '''
        Parameters
        ----------
        resolution                  : list
                                      Resolution of the display in pixels.
        distance_from_screen        : int
                                      Distance from the screen in mm.
        pixel_pitch                 : float
                                      Pixel pitch of the display in mm.
        read_spectrum               : str
                                      Spectrum of the display. Default is 'default' which is the spectrum of the Dell U2415 display.
        device                      : torch.device
                                      Device to run the code on. Default is None which means the code will run on CPU.

        '''
        self.device = device
        self.read_spectrum = read_spectrum
        self.primaries_spectrum = primaries_spectrum.to(self.device)
        self.resolution = resolution
        self.distance_from_screen = distance_from_screen
        self.pixel_pitch = pixel_pitch
        self.l_normalised, self.m_normalised, self.s_normalised = self.initialize_cones_normalised()
        self.lms_tensor = self.construct_matrix_lms(
                                                    self.l_normalised,
                                                    self.m_normalised,
                                                    self.s_normalised
                                                   )   
        self.primaries_tensor = self.construct_matrix_primaries(
                                                    self.l_normalised,
                                                    self.m_normalised,
                                                    self.s_normalised
                                                   )   
        return


    def __call__(self, input_image, ground_truth, gaze=None):
        """
        Evaluating an input image against a target ground truth image for a given gaze of a viewer.
        """
        lms_image_second = self.primaries_to_lms(input_image.to(self.device))
        lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))
        lms_image_third = self.second_to_third_stage(lms_image_second)
        lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)
        loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)
        return loss_metamer_color


    def initialize_cones_normalised(self):
        """
        Internal function to initialize normalised L,M,S cones as normal distribution with given sigma, and mu values. 

        Returns
        -------
        l_cone_n                     : torch.tensor
                                       Normalised L cone distribution.
        m_cone_n                     : torch.tensor
                                       Normalised M cone distribution.
        s_cone_n                     : torch.tensor
                                       Normalised S cone distribution.
        """
        wavelength_range = np_cpu.linspace(400, 700, num=301)
        dist_l = [1 / (32.5 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (wavelength_range[i] -
                                                                               567.5)**2 / (2 * 32.5**2)) for i in range(len(wavelength_range))]
        dist_m = [1 / (27.5 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (wavelength_range[i] -
                                                                               545.0)**2 / (2 * 27.5**2)) for i in range(len(wavelength_range))]
        dist_s = [1 / (17.0 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (wavelength_range[i] -
                                                                               447.5)**2 / (2 * 17.0**2)) for i in range(len(wavelength_range))]

        l_cone_n = torch.from_numpy(dist_l/max(dist_l))
        m_cone_n = torch.from_numpy(dist_m/max(dist_m))
        s_cone_n = torch.from_numpy(dist_s/max(dist_s))
        return l_cone_n.to(self.device), m_cone_n.to(self.device), s_cone_n.to(self.device)


    def initialize_rgb_backlight_spectrum(self):
        """
        Internal function to initialize baclight spectrum for color primaries. 

        Returns
        -------
        red_spectrum                 : torch.tensor
                                       Normalised backlight spectrum for red color primary.
        green_spectrum               : torch.tensor
                                       Normalised backlight spectrum for green color primary.
        blue_spectrum                : torch.tensor
                                       Normalised backlight spectrum for blue color primary.
        """
        wavelength_range = np_cpu.linspace(400, 700, num=301)
        red_spectrum = [1 / (14.5 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (
            wavelength_range[i] - 650)**2 / (2 * 14.5**2)) for i in range(len(wavelength_range))]
        green_spectrum = [1 / (12 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (
            wavelength_range[i] - 550)**2 / (2 * 12.0**2)) for i in range(len(wavelength_range))]
        blue_spectrum = [1 / (12 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (
            wavelength_range[i] - 450)**2 / (2 * 12.0**2)) for i in range(len(wavelength_range))]

        red_spectrum = torch.from_numpy(
            red_spectrum / max(red_spectrum)) * 1.0
        green_spectrum = torch.from_numpy(
            green_spectrum / max(green_spectrum)) * 1.0
        blue_spectrum = torch.from_numpy(
            blue_spectrum / max(blue_spectrum)) * 1.0

        return red_spectrum.to(self.device), green_spectrum.to(self.device), blue_spectrum.to(self.device)


    def initialize_random_spectrum_normalised(self, dataset):
        """
        Initialize normalised light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. 

        Parameters
        ----------
        dataset                                : torch.tensor 
                                                 spectrum value against wavelength 
        """
        if (type(dataset).__module__) == "torch":
            dataset = dataset.numpy()
        if dataset.shape[0] > dataset.shape[1]:
            dataset = np_cpu.swapaxes(dataset, 0, 1)
        x_spectrum = np_cpu.linspace(400, 700, num=301)
        y_spectrum = np_cpu.interp(x_spectrum, dataset[0], dataset[1])
        x_spectrum = torch.from_numpy(x_spectrum) - 550
        y_spectrum = torch.from_numpy(y_spectrum)
        max_spectrum = torch.max(y_spectrum)
        y_spectrum /= max_spectrum

        def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \
            torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))

        def function(x, weights): return gaussian(
            x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])
        weights = torch.tensor(
            [1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad=True)
        optimizer = torch.optim.LBFGS(
            [weights], max_iter = 1000, lr = 0.1, line_search_fn=None)

        def closure():
            optimizer.zero_grad()
            output = function(x_spectrum, weights)
            loss = F.mse_loss(output, y_spectrum)
            loss.backward()
            return loss
        optimizer.step(closure)
        spectrum = function(x_spectrum, weights)
        return spectrum.detach().to(self.device)


    def display_spectrum_response(wavelength, function):
        """
        Internal function to provide light spectrum response at particular wavelength

        Parameters
        ----------
        wavelength                          : torch.tensor
                                              Wavelength in nm [400...700]
        function                            : torch.tensor
                                              Display light spectrum distribution function

        Returns
        -------
        ligth_response_dict                  : float
                                               Display light spectrum response value
        """
        wavelength = int(round(wavelength, 0))
        if wavelength >= 400 and wavelength <= 700:
            return function[wavelength - 400].item()
        elif wavelength < 400:
            return function[0].item()
        else:
            return function[300].item()


    def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
        """
        Internal function to calculate cone response at particular light spectrum. 

        Parameters
        ----------
        cone_spectrum                         : torch.tensor
                                                Spectrum, Wavelength [2,300] tensor 
        light_spectrum                        : torch.tensor
                                                Spectrum, Wavelength [2,300] tensor 


        Returns
        -------
        response_to_spectrum                  : float
                                                Response of cone to light spectrum [1x1] 
        """
        response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)
        response_to_spectrum = torch.sum(response_to_spectrum)
        return response_to_spectrum.item()


    def construct_matrix_lms(self, l_response, m_response, s_response):
        '''
        Internal function to calculate cone  response at particular light spectrum. 

        Parameters
        ----------
        l_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)
        m_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)
        s_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)



        Returns
        -------
        lms_image_tensor                      : torch.tensor
                                                3x3 LMSrgb tensor

        '''
        if self.read_spectrum == 'tensor':
            logging.warning('Tensor primary spectrum is used')
            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
        else:
            logging.warning("No Spectrum data is provided")

        self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)
        for i in range(self.primaries_spectrum.shape[0]):
            self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response,
                                                                   self.primaries_spectrum[i]
                                                                   )
            self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response,
                                                                   self.primaries_spectrum[i]
                                                                   )
            self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response,
                                                                   self.primaries_spectrum[i]
                                                                   ) 
        return self.lms_tensor    

    def construct_matrix_primaries(self, l_response, m_response, s_response):
        '''
        Internal function to calculate cone  response at particular light spectrum. 

        Parameters
        ----------
        l_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)
        m_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)
        s_response                             : torch.tensor
                                                 Cone response spectrum tensor (normalised response vs wavelength)



        Returns
        -------
        lms_image_tensor                      : torch.tensor
                                                3x3 LMSrgb tensor

        '''
        if self.read_spectrum == 'tensor':
            logging.warning('Tensor primary spectrum is used')
            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
        else:
            logging.warning("No Spectrum data is provided")

        self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)
        for i in range(self.primaries_spectrum.shape[0]):
            self.primaries_tensor[0, i] = self.cone_response_to_spectrum(l_response,
                                                                   self.primaries_spectrum[i]
                                                                   )
            self.primaries_tensor[1, i] = self.cone_response_to_spectrum(m_response,
                                                                   self.primaries_spectrum[i]
                                                                   )
            self.primaries_tensor[2, i] = self.cone_response_to_spectrum(s_response,
                                                                   self.primaries_spectrum[i]
                                                                   ) 
        return self.primaries_tensor    


    def primaries_to_lms(self, primaries):
        """
        Internal function to convert primaries space to LMS space 

        Parameters
        ----------
        primaries                              : torch.tensor
                                                 Primaries data to be transformed to LMS space [BxPHxW]


        Returns
        -------
        lms_color                              : torch.tensor
                                                 LMS data transformed from Primaries space [BxPxHxW]
        """                
        primaries = primaries.permute(0, 2, 3, 1).to(self.device)
        primaries_flatten = torch.flatten(primaries, start_dim = 1, end_dim = 2)
        unflatten = torch.nn.Unflatten(1, (primaries.size(1), primaries.size(2)))
        converted_unflatten = torch.matmul(primaries_flatten.double(), self.lms_tensor.double())
        lms_color = unflatten(converted_unflatten)        
        lms_color = lms_color.permute(0, 3, 1, 2)
        return lms_color


    def lms_to_primaries(self, lms_color_tensor):
        """
        Internal function to convert LMS image to primaries space

        Parameters
        ----------
        lms_color_tensor                        : torch.tensor
                                                  LMS data to be transformed to primaries space [Bx3xHxW]


        Returns
        -------
        primaries                              : torch.tensor
                                               : Primaries data transformed from LMS space [BxPxHxW]
        """
        lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)
        lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)
        unflatten = torch.nn.Unflatten(
            0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))
        converted_unflatten = torch.matmul(
            lms_color_flatten.double(), self.lms_tensor.pinverse().double())
        primaries = unflatten(converted_unflatten)     
        primaries = primaries.permute(0, 3, 1, 2)   
        return primaries


    def second_to_third_stage(self, lms_image):
        '''
        This function turns second stage [L,M,S] values into third stage [(L+S)-M, M-(L+S), (M+S)-L]
        Equations are taken from Schmidt et al "Neurobiological hypothesis of color appearance and hue perception" 2014

        Parameters
        ----------
        lms_image                             : torch.tensor
                                                 Image data at LMS space (second stage)

        Returns
        -------
        third_stage                            : torch.tensor
                                                 Image data at LMS space (third stage)

        '''
        lms_image = lms_image.permute(0,2,3,1)
        third_stage = torch.zeros(lms_image.shape[0],
            lms_image.shape[1], lms_image.shape[2], 3).to(self.device)
        third_stage[:, :, :, 0] = (lms_image[:, :, :, 1] +
                                lms_image[:, :, :, 2]) - lms_image[:, :, :, 0]
        third_stage[:, :, :, 1] = (lms_image[:, :, :, 0] +
                                lms_image[:, :, :, 2]) - lms_image[:, :, :, 1]
        third_stage[:, :, :, 2] = torch.sum(lms_image, dim=3) / 3.
        third_stage = third_stage.permute(0, 3, 1, 2)
        return third_stage


    def to(self, device):
        """
        Utilization function for setting the device.
        Parameters
        ----------
        device       : torch.device
                       Device to be used (e.g., CPU, Cuda, OpenCL).
        """
        self.device = device
        return self
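A hedged usage sketch for this class follows; the random primaries spectrum and image shapes are placeholders, not measured display data.

import torch
from odak.learn.perception import display_color_hvs

# Sketch: three display primaries, each sampled at 301 wavelengths (400-700 nm).
# Replace the random spectrum with measured primary spectra in practice.
loss_fn = display_color_hvs(primaries_spectrum = torch.rand(3, 301))
image = torch.rand(1, 3, 128, 128, requires_grad = True)
ground_truth = torch.rand(1, 3, 128, 128)
loss = loss_fn(image, ground_truth)  # mean squared error in the third-stage opponent space
loss.backward()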

__call__(input_image, ground_truth, gaze=None)

Evaluates an input image against a ground truth target image for a given viewer gaze (note that the gaze argument is accepted but unused in the listing below).

Source code in odak/learn/perception/color_conversion.py
def __call__(self, input_image, ground_truth, gaze=None):
    """
    Evaluating an input image against a target ground truth image for a given gaze of a viewer.
    """
    lms_image_second = self.primaries_to_lms(input_image.to(self.device))
    lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))
    lms_image_third = self.second_to_third_stage(lms_image_second)
    lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)
    loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)
    return loss_metamer_color

__init__(resolution=[1920, 1080], distance_from_screen=800, pixel_pitch=0.311, read_spectrum='tensor', primaries_spectrum=torch.rand(3, 301), device=torch.device('cpu'))

Parameters:

  • resolution
                          Resolution of the display in pixels.
    
  • distance_from_screen
                          Distance from the screen in mm.
    
  • pixel_pitch
                          Pixel pitch of the display in mm.
    
  • read_spectrum
                          Spectrum source for the display. The signature defaults to 'tensor', in which case primaries_spectrum is used.
    
  • primaries_spectrum
                          Spectra of the display primaries, one row per primary, sampled at 301 wavelengths (400-700 nm).
    
  • device
                          Device to run the code on. Defaults to the CPU.
    
Source code in odak/learn/perception/color_conversion.py
def __init__(self, resolution = [1920, 1080],
             distance_from_screen = 800,
             pixel_pitch = 0.311,
             read_spectrum = 'tensor',
             primaries_spectrum = torch.rand(3, 301),
             device = torch.device('cpu')):
    '''
    Parameters
    ----------
    resolution                  : list
                                  Resolution of the display in pixels.
    distance_from_screen        : int
                                  Distance from the screen in mm.
    pixel_pitch                 : float
                                  Pixel pitch of the display in mm.
    read_spectrum               : str
                                  Spectrum of the display. Default is 'default' which is the spectrum of the Dell U2415 display.
    device                      : torch.device
                                  Device to run the code on. Default is None which means the code will run on CPU.

    '''
    self.device = device
    self.read_spectrum = read_spectrum
    self.primaries_spectrum = primaries_spectrum.to(self.device)
    self.resolution = resolution
    self.distance_from_screen = distance_from_screen
    self.pixel_pitch = pixel_pitch
    self.l_normalised, self.m_normalised, self.s_normalised = self.initialize_cones_normalised()
    self.lms_tensor = self.construct_matrix_lms(
                                                self.l_normalised,
                                                self.m_normalised,
                                                self.s_normalised
                                               )   
    self.primaries_tensor = self.construct_matrix_primaries(
                                                self.l_normalised,
                                                self.m_normalised,
                                                self.s_normalised
                                               )   
    return

cone_response_to_spectrum(cone_spectrum, light_spectrum)

Internal function to calculate the cone response to a particular light spectrum.

Parameters:

  • cone_spectrum
                                    Spectrum, Wavelength [2,300] tensor
    
  • light_spectrum
                                    Spectrum, Wavelength [2,300] tensor
    

Returns:

  • response_to_spectrum ( float ) –

    Response of cone to light spectrum [1x1]

Source code in odak/learn/perception/color_conversion.py
def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
    """
    Internal function to calculate cone response at particular light spectrum. 

    Parameters
    ----------
    cone_spectrum                         : torch.tensor
                                            Spectrum, Wavelength [2,300] tensor 
    light_spectrum                        : torch.tensor
                                            Spectrum, Wavelength [2,300] tensor 


    Returns
    -------
    response_to_spectrum                  : float
                                            Response of cone to light spectrum [1x1] 
    """
    response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)
    response_to_spectrum = torch.sum(response_to_spectrum)
    return response_to_spectrum.item()
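The response is effectively an inner product of the two spectra. A small self-contained sketch, assuming 301-sample spectra over 400-700 nm (the shapes used elsewhere in this class):

import torch

# Sketch: a cone sensitivity curve and a narrow-band light spectrum.
wavelengths = torch.linspace(400, 700, 301)
cone = torch.exp(-0.5 * ((wavelengths - 567.5) / 32.5) ** 2)   # L-cone-like curve
light = torch.exp(-0.5 * ((wavelengths - 550.0) / 12.0) ** 2)  # green-ish primary
response = torch.sum(cone * light).item()  # same operation as cone_response_to_spectrum(cone, light)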

construct_matrix_lms(l_response, m_response, s_response)

Internal function to construct the LMS conversion matrix from the display primaries' spectra and the cone response spectra.

Parameters:

  • l_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    
  • m_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    
  • s_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    

Returns:

  • lms_image_tensor ( tensor ) –

    3x3 LMSrgb tensor

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_lms(self, l_response, m_response, s_response):
    '''
    Internal function to calculate cone  response at particular light spectrum. 

    Parameters
    ----------
    l_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)
    m_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)
    s_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)



    Returns
    -------
    lms_image_tensor                      : torch.tensor
                                            3x3 LMSrgb tensor

    '''
    if self.read_spectrum == 'tensor':
        logging.warning('Tensor primary spectrum is used')
        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
    else:
        logging.warning("No Spectrum data is provided")

    self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)
    for i in range(self.primaries_spectrum.shape[0]):
        self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response,
                                                               self.primaries_spectrum[i]
                                                               )
        self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response,
                                                               self.primaries_spectrum[i]
                                                               )
        self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response,
                                                               self.primaries_spectrum[i]
                                                               ) 
    return self.lms_tensor    

construct_matrix_primaries(l_response, m_response, s_response)

Internal function to construct the primaries conversion matrix from the display primaries' spectra and the cone response spectra.

Parameters:

  • l_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    
  • m_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    
  • s_response
                                     Cone response spectrum tensor (normalised response vs wavelength)
    

Returns:

  • lms_image_tensor ( tensor ) –

    3x3 LMSrgb tensor

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_primaries(self, l_response, m_response, s_response):
    '''
    Internal function to calculate cone  response at particular light spectrum. 

    Parameters
    ----------
    l_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)
    m_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)
    s_response                             : torch.tensor
                                             Cone response spectrum tensor (normalised response vs wavelength)



    Returns
    -------
    lms_image_tensor                      : torch.tensor
                                            3x3 LMSrgb tensor

    '''
    if self.read_spectrum == 'tensor':
        logging.warning('Tensor primary spectrum is used')
        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
    else:
        logging.warning("No Spectrum data is provided")

    self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)
    for i in range(self.primaries_spectrum.shape[0]):
        self.primaries_tensor[0, i] = self.cone_response_to_spectrum(l_response,
                                                               self.primaries_spectrum[i]
                                                               )
        self.primaries_tensor[1, i] = self.cone_response_to_spectrum(m_response,
                                                               self.primaries_spectrum[i]
                                                               )
        self.primaries_tensor[2, i] = self.cone_response_to_spectrum(s_response,
                                                               self.primaries_spectrum[i]
                                                               ) 
    return self.primaries_tensor    

display_spectrum_response(wavelength, function)

Internal function to provide the light spectrum response at a particular wavelength.

Parameters:

  • wavelength
                                  Wavelength in nm [400...700]
    
  • function
                                  Display light spectrum distribution function
    

Returns:

  • light_response ( float ) –

    Display light spectrum response value.

Source code in odak/learn/perception/color_conversion.py
def display_spectrum_response(wavelength, function):
    """
    Internal function to provide light spectrum response at particular wavelength

    Parameters
    ----------
    wavelength                          : torch.tensor
                                          Wavelength in nm [400...700]
    function                            : torch.tensor
                                          Display light spectrum distribution function

    Returns
    -------
    ligth_response_dict                  : float
                                           Display light spectrum response value
    """
    wavelength = int(round(wavelength, 0))
    if wavelength >= 400 and wavelength <= 700:
        return function[wavelength - 400].item()
    elif wavelength < 400:
        return function[0].item()
    else:
        return function[300].item()

initialize_cones_normalised()

Internal function to initialize the normalised L, M, S cone sensitivities as normal distributions with given sigma and mu values.

Returns:

  • l_cone_n ( tensor ) –

    Normalised L cone distribution.

  • m_cone_n ( tensor ) –

    Normalised M cone distribution.

  • s_cone_n ( tensor ) –

    Normalised S cone distribution.

Source code in odak/learn/perception/color_conversion.py
def initialize_cones_normalised(self):
    """
    Internal function to initialize normalised L,M,S cones as normal distribution with given sigma, and mu values. 

    Returns
    -------
    l_cone_n                     : torch.tensor
                                   Normalised L cone distribution.
    m_cone_n                     : torch.tensor
                                   Normalised M cone distribution.
    s_cone_n                     : torch.tensor
                                   Normalised S cone distribution.
    """
    wavelength_range = np_cpu.linspace(400, 700, num=301)
    dist_l = [1 / (32.5 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (wavelength_range[i] -
                                                                           567.5)**2 / (2 * 32.5**2)) for i in range(len(wavelength_range))]
    dist_m = [1 / (27.5 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (wavelength_range[i] -
                                                                           545.0)**2 / (2 * 27.5**2)) for i in range(len(wavelength_range))]
    dist_s = [1 / (17.0 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (wavelength_range[i] -
                                                                           447.5)**2 / (2 * 17.0**2)) for i in range(len(wavelength_range))]

    l_cone_n = torch.from_numpy(dist_l/max(dist_l))
    m_cone_n = torch.from_numpy(dist_m/max(dist_m))
    s_cone_n = torch.from_numpy(dist_s/max(dist_s))
    return l_cone_n.to(self.device), m_cone_n.to(self.device), s_cone_n.to(self.device)
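After the division by the maximum, each normalised cone sensitivity in the listing reduces to a unit-peak Gaussian in wavelength (the 1/(σ√2π) prefactor cancels; λ is sampled at 301 points over 400-700 nm):

L(\lambda) = \exp\left(-\frac{(\lambda - 567.5)^2}{4 \cdot 32.5^2}\right), \quad
M(\lambda) = \exp\left(-\frac{(\lambda - 545.0)^2}{4 \cdot 27.5^2}\right), \quad
S(\lambda) = \exp\left(-\frac{(\lambda - 447.5)^2}{4 \cdot 17.0^2}\right)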

initialize_random_spectrum_normalised(dataset)

Initialize a normalised light spectrum by fitting a combination of three Gaussian distributions to a measured spectrum via L-BFGS curve fitting.

Parameters:

  • dataset
                                     Spectrum values against wavelength: a [2 x N] array with wavelengths (nm) in one row and spectrum values in the other.
    
Source code in odak/learn/perception/color_conversion.py
def initialize_random_spectrum_normalised(self, dataset):
    """
    Initialize normalised light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. 

    Parameters
    ----------
    dataset                                : torch.tensor 
                                             spectrum value against wavelength 
    """
    if (type(dataset).__module__) == "torch":
        dataset = dataset.numpy()
    if dataset.shape[0] > dataset.shape[1]:
        dataset = np_cpu.swapaxes(dataset, 0, 1)
    x_spectrum = np_cpu.linspace(400, 700, num=301)
    y_spectrum = np_cpu.interp(x_spectrum, dataset[0], dataset[1])
    x_spectrum = torch.from_numpy(x_spectrum) - 550
    y_spectrum = torch.from_numpy(y_spectrum)
    max_spectrum = torch.max(y_spectrum)
    y_spectrum /= max_spectrum

    def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \
        torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))

    def function(x, weights): return gaussian(
        x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])
    weights = torch.tensor(
        [1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad=True)
    optimizer = torch.optim.LBFGS(
        [weights], max_iter = 1000, lr = 0.1, line_search_fn=None)

    def closure():
        optimizer.zero_grad()
        output = function(x_spectrum, weights)
        loss = F.mse_loss(output, y_spectrum)
        loss.backward()
        return loss
    optimizer.step(closure)
    spectrum = function(x_spectrum, weights)
    return spectrum.detach().to(self.device)
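Concretely, the listing fits the peak-normalised measured spectrum with a sum of three Gaussians in the shifted wavelength x = λ - 550 nm,

\hat{s}(x) = \sum_{k=1}^{3} A_k \exp\left(-\frac{(x - c_k)^2}{2 \sigma_k^2}\right),

where the nine parameters (A_k, σ_k, c_k) are found by L-BFGS minimising the mean squared error against the measured curve.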

initialize_rgb_backlight_spectrum()

Internal function to initialize the backlight spectrum for the color primaries.

Returns:

  • red_spectrum ( tensor ) –

    Normalised backlight spectrum for red color primary.

  • green_spectrum ( tensor ) –

    Normalised backlight spectrum for green color primary.

  • blue_spectrum ( tensor ) –

    Normalised backlight spectrum for blue color primary.

Source code in odak/learn/perception/color_conversion.py
def initialize_rgb_backlight_spectrum(self):
    """
    Internal function to initialize baclight spectrum for color primaries. 

    Returns
    -------
    red_spectrum                 : torch.tensor
                                   Normalised backlight spectrum for red color primary.
    green_spectrum               : torch.tensor
                                   Normalised backlight spectrum for green color primary.
    blue_spectrum                : torch.tensor
                                   Normalised backlight spectrum for blue color primary.
    """
    wavelength_range = np_cpu.linspace(400, 700, num=301)
    red_spectrum = [1 / (14.5 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (
        wavelength_range[i] - 650)**2 / (2 * 14.5**2)) for i in range(len(wavelength_range))]
    green_spectrum = [1 / (12 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (
        wavelength_range[i] - 550)**2 / (2 * 12.0**2)) for i in range(len(wavelength_range))]
    blue_spectrum = [1 / (12 * np_cpu.sqrt(2 * np_cpu.pi)) * np_cpu.exp(-0.5 * (
        wavelength_range[i] - 450)**2 / (2 * 12.0**2)) for i in range(len(wavelength_range))]

    red_spectrum = torch.from_numpy(
        red_spectrum / max(red_spectrum)) * 1.0
    green_spectrum = torch.from_numpy(
        green_spectrum / max(green_spectrum)) * 1.0
    blue_spectrum = torch.from_numpy(
        blue_spectrum / max(blue_spectrum)) * 1.0

    return red_spectrum.to(self.device), green_spectrum.to(self.device), blue_spectrum.to(self.device)

lms_to_primaries(lms_color_tensor)

Internal function to convert LMS image to primaries space

Parameters:

  • lms_color_tensor
                                      LMS data to be transformed to primaries space [Bx3xHxW]
    

Returns:

  • primaries ( tensor ) –

    Primaries data transformed from LMS space [BxPxHxW]

Source code in odak/learn/perception/color_conversion.py
def lms_to_primaries(self, lms_color_tensor):
    """
    Internal function to convert LMS image to primaries space

    Parameters
    ----------
    lms_color_tensor                        : torch.tensor
                                              LMS data to be transformed to primaries space [Bx3xHxW]


    Returns
    -------
    primaries                              : torch.tensor
                                           : Primaries data transformed from LMS space [BxPxHxW]
    """
    lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)
    lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)
    unflatten = torch.nn.Unflatten(
        0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))
    converted_unflatten = torch.matmul(
        lms_color_flatten.double(), self.lms_tensor.pinverse().double())
    primaries = unflatten(converted_unflatten)     
    primaries = primaries.permute(0, 3, 1, 2)   
    return primaries

primaries_to_lms(primaries)

Internal function to convert primaries space to LMS space

Parameters:

  • primaries
                                     Primaries data to be transformed to LMS space [BxPxHxW]
    

Returns:

  • lms_color ( tensor ) –

    LMS data transformed from Primaries space [BxPxHxW]

Source code in odak/learn/perception/color_conversion.py
def primaries_to_lms(self, primaries):
    """
    Internal function to convert primaries space to LMS space 

    Parameters
    ----------
    primaries                              : torch.tensor
                                             Primaries data to be transformed to LMS space [BxPHxW]


    Returns
    -------
    lms_color                              : torch.tensor
                                             LMS data transformed from Primaries space [BxPxHxW]
    """                
    primaries = primaries.permute(0, 2, 3, 1).to(self.device)
    primaries_flatten = torch.flatten(primaries, start_dim = 1, end_dim = 2)
    unflatten = torch.nn.Unflatten(1, (primaries.size(1), primaries.size(2)))
    converted_unflatten = torch.matmul(primaries_flatten.double(), self.lms_tensor.double())
    lms_color = unflatten(converted_unflatten)        
    lms_color = lms_color.permute(0, 3, 1, 2)
    return lms_color
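A hedged round-trip sketch using primaries_to_lms together with lms_to_primaries above; with three primaries the LMS matrix is square, so the round trip should recover the input up to numerical error (shapes are illustrative):

import torch
from odak.learn.perception import display_color_hvs

converter = display_color_hvs(primaries_spectrum = torch.rand(3, 301))
rgb = torch.rand(1, 3, 64, 64)
lms = converter.primaries_to_lms(rgb)        # primaries -> LMS, [1 x 3 x 64 x 64]
rgb_back = converter.lms_to_primaries(lms)   # LMS -> primaries
print(torch.allclose(rgb.double(), rgb_back, atol = 1e-4))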

second_to_third_stage(lms_image)

This function turns second-stage [L, M, S] values into third-stage [(L+S)-M, M-(L+S), (M+S)-L] values. Equations are taken from Schmidt et al., "Neurobiological hypothesis of color appearance and hue perception", 2014.

Parameters:

  • lms_image
                                     Image data at LMS space (second stage)
    

Returns:

  • third_stage ( tensor ) –

    Image data at LMS space (third stage)

Source code in odak/learn/perception/color_conversion.py
def second_to_third_stage(self, lms_image):
    '''
    This function turns second stage [L,M,S] values into third stage [(L+S)-M, M-(L+S), (M+S)-L]
    Equations are taken from Schmidt et al "Neurobiological hypothesis of color appearance and hue perception" 2014

    Parameters
    ----------
    lms_image                             : torch.tensor
                                             Image data at LMS space (second stage)

    Returns
    -------
    third_stage                            : torch.tensor
                                             Image data at LMS space (third stage)

    '''
    lms_image = lms_image.permute(0,2,3,1)
    third_stage = torch.zeros(lms_image.shape[0],
        lms_image.shape[1], lms_image.shape[2], 3).to(self.device)
    third_stage[:, :, :, 0] = (lms_image[:, :, :, 1] +
                            lms_image[:, :, :, 2]) - lms_image[:, :, :, 0]
    third_stage[:, :, :, 1] = (lms_image[:, :, :, 0] +
                            lms_image[:, :, :, 2]) - lms_image[:, :, :, 1]
    third_stage[:, :, :, 2] = torch.sum(lms_image, dim=3) / 3.
    third_stage = third_stage.permute(0, 3, 1, 2)
    return third_stage

to(device)

Utility function for setting the device.

Parameters:

  • device
           Device to be used (e.g., CPU, Cuda, OpenCL).
    
Source code in odak/learn/perception/color_conversion.py
def to(self, device):
    """
    Utilization function for setting the device.
    Parameters
    ----------
    device       : torch.device
                   Device to be used (e.g., CPU, Cuda, OpenCL).
    """
    self.device = device
    return self

color_map(input_image, target_image, model='Lab Stats')

Internal function to map the color of an image to another image. Reference: Color transfer between images, Reinhard et al., 2001.

Parameters:

  • input_image
                  Input image in RGB color space [3 x m x n].
    
  • target_image
                  Target image in RGB color space [3 x m x n] whose colour statistics are transferred to the input.
    
Returns:

  • mapped_image ( Tensor ) –

    Input image with the color distribution of the target image [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def color_map(input_image, target_image, model = 'Lab Stats'):
    """
    Internal function to map the color of an image to another image.
    Reference: Color transfer between images, Reinhard et al., 2001.

    Parameters
    ----------
    input_image         : torch.Tensor
                          Input image in RGB color space [3 x m x n].
    target_image        : torch.Tensor

    Returns
    -------
    mapped_image           : torch.Tensor
                             Input image with the color the distribution of the target image [3 x m x n].
    """
    if model == 'Lab Stats':
        lab_input = srgb_to_lab(input_image)
        lab_target = srgb_to_lab(target_image)
        input_mean_L = torch.mean(lab_input[0, :, :])
        input_mean_a = torch.mean(lab_input[1, :, :])
        input_mean_b = torch.mean(lab_input[2, :, :])
        input_std_L = torch.std(lab_input[0, :, :])
        input_std_a = torch.std(lab_input[1, :, :])
        input_std_b = torch.std(lab_input[2, :, :])
        target_mean_L = torch.mean(lab_target[0, :, :])
        target_mean_a = torch.mean(lab_target[1, :, :])
        target_mean_b = torch.mean(lab_target[2, :, :])
        target_std_L = torch.std(lab_target[0, :, :])
        target_std_a = torch.std(lab_target[1, :, :])
        target_std_b = torch.std(lab_target[2, :, :])
        lab_input[0, :, :] = (lab_input[0, :, :] - input_mean_L) * (target_std_L / input_std_L) + target_mean_L
        lab_input[1, :, :] = (lab_input[1, :, :] - input_mean_a) * (target_std_a / input_std_a) + target_mean_a
        lab_input[2, :, :] = (lab_input[2, :, :] - input_mean_b) * (target_std_b / input_std_b) + target_mean_b
        mapped_image = lab_to_srgb(lab_input.permute(1, 2, 0))
        return mapped_image
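A brief usage sketch (random tensors stand in for real images; only the 'Lab Stats' model from the listing is exercised):

import torch
from odak.learn.perception import color_map

# Sketch: transfer the colour statistics of target_image onto input_image.
input_image = torch.rand(3, 256, 256)   # RGB, [3 x m x n], values in [0, 1]
target_image = torch.rand(3, 256, 256)
mapped_image = color_map(input_image, target_image, model = 'Lab Stats')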

hsv_to_rgb(image)

Definition to convert HSV space to RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

  • image
              Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    

Returns:

  • image_rgb ( tensor ) –

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def hsv_to_rgb(image):

    """
    Definition to convert HSV space to  RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

    Parameters
    ----------
    image           : torch.tensor
                      Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

    Returns
    -------
    image_rgb       : torch.tensor
                      Output image in  RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    h = image[..., 0, :, :] / (2 * math.pi)
    s = image[..., 1, :, :]
    v = image[..., 2, :, :]
    hi = torch.floor(h * 6) % 6
    f = ((h * 6) % 6) - hi
    one = torch.tensor(1.0)
    p = v * (one - s)
    q = v * (one - f * s)
    t = v * (one - (one - f) * s)
    hi = hi.long()
    indices = torch.stack([hi, hi + 6, hi + 12], dim=-3)
    image_rgb = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
    image_rgb = torch.gather(image_rgb, -3, indices)
    return image_rgb
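A hedged round-trip sketch pairing this with rgb_to_hsv (documented further below); hue is carried in radians, which is why the listing divides by 2π:

import torch
from odak.learn.perception import rgb_to_hsv, hsv_to_rgb

image_rgb = torch.rand(1, 3, 64, 64)   # RGB, normalised to [0, 1]
image_hsv = rgb_to_hsv(image_rgb)      # hue in radians
image_back = hsv_to_rgb(image_hsv)     # approximately recovers image_rgb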

lab_to_srgb(image)

Definition to convert LAB space to SRGB color space.

Parameters:

  • image
              Input image in LAB color space[3 x m x n]
    

Returns:

  • image_srgb ( tensor ) –

    Output image in SRGB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def lab_to_srgb(image):
    """
    Definition to convert LAB space to SRGB color space. 

    Parameters
    ----------
    image           : torch.tensor
                      Input image in LAB color space[3 x m x n]
    Returns
    -------
    image_srgb     : torch.tensor
                      Output image in SRGB color space [3 x m x n].
    """

    if image.shape[-1] == 3:
        input_color = image.permute(2, 0, 1)  # C(H*W)
    else:
        input_color = image
    # lab ---> xyz
    reference_illuminant = torch.tensor([[[0.950428545]], [[1.000000000]], [[1.088900371]]], dtype=torch.float32)
    y = (input_color[0:1, :, :] + 16) / 116
    a =  input_color[1:2, :, :] / 500
    b =  input_color[2:3, :, :] / 200
    x = y + a
    z = y - b
    xyz = torch.cat((x, y, z), 0)
    delta = 6 / 29
    factor = 3 * delta * delta
    xyz = torch.where(xyz > delta,  xyz ** 3, factor * (xyz - 4 / 29))
    xyz_color = xyz * reference_illuminant
    # xyz ---> linear rgb
    a11 = 3.241003275
    a12 = -1.537398934
    a13 = -0.498615861
    a21 = -0.969224334
    a22 = 1.875930071
    a23 = 0.041554224
    a31 = 0.055639423
    a32 = -0.204011202
    a33 = 1.057148933
    A = torch.tensor([[a11, a12, a13],
                  [a21, a22, a23],
                  [a31, a32, a33]], dtype=torch.float32)

    xyz_color = xyz_color.permute(2, 0, 1) # C(H*W)
    linear_rgb_color = torch.matmul(A, xyz_color)
    linear_rgb_color = linear_rgb_color.permute(1, 2, 0)
    # linear rgb ---> srgb
    limit = 0.0031308
    image_srgb = torch.where(linear_rgb_color > limit, 1.055 * (linear_rgb_color ** (1.0 / 2.4)) - 0.055, 12.92 * linear_rgb_color)
    return image_srgb
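A short sketch pairing this with srgb_to_lab (used by color_map above, and assumed here to be importable from the same module); shapes follow the [3 x m x n] convention stated in the parameters:

import torch
from odak.learn.perception import srgb_to_lab, lab_to_srgb

image_srgb = torch.rand(3, 128, 128)   # sRGB image, [3 x m x n], values in [0, 1]
image_lab = srgb_to_lab(image_srgb)    # to CIE L*a*b*
image_back = lab_to_srgb(image_lab)    # back to sRGB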

linear_rgb_to_rgb(image, threshold=0.0031308)

Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

  • image
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    
  • threshold
              Threshold used in calculations.
    

Returns:

  • image_linear ( tensor ) –

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_rgb(image, threshold = 0.0031308):
    """
    Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

    Parameters
    ----------
    image           : torch.tensor
                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    threshold       : float
                      Threshold used in calculations.

    Returns
    -------
    image_linear    : torch.tensor
                      Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    image_linear =  torch.where(image > threshold, 1.055 * torch.pow(image.clamp(min=threshold), 1 / 2.4) - 0.055, 12.92 * image)
    return image_linear
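
A short sketch with the same assumed import path, pairing linear_rgb_to_rgb with rgb_to_linear_rgb (documented further below):

import torch
from odak.learn.perception.color_conversion import rgb_to_linear_rgb, linear_rgb_to_rgb

image = torch.rand(3, 32, 32)                  # sRGB-encoded image in [0, 1]
image_linear = rgb_to_linear_rgb(image)        # de-gamma, returns [1 x 3 x 32 x 32]
image_back = linear_rgb_to_rgb(image_linear)   # re-apply the sRGB curve
print(torch.abs(image_back.squeeze(0) - image).max())   # should be close to zero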

linear_rgb_to_xyz(image)

Definition to convert linear RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

  • image
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    

Returns:

  • image_xyz ( tensor ) –

    Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_xyz(image):
    """
    Definition to convert linear RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

    Parameters
    ----------
    image           : torch.tensor
                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

    Returns
    -------
    image_xyz       : torch.tensor
                      Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    a11 = 0.412453
    a12 = 0.357580
    a13 = 0.180423
    a21 = 0.212671
    a22 = 0.715160
    a23 = 0.072169
    a31 = 0.019334
    a32 = 0.119193
    a33 = 0.950227
    M = torch.tensor([[a11, a12, a13], 
                      [a21, a22, a23],
                      [a31, a32, a33]])
    size = image.size()
    image = image.reshape(size[0], size[1], size[2]*size[3])  # NC(HW)
    image_xyz = torch.matmul(M, image)
    image_xyz = image_xyz.reshape(size[0], size[1], size[2], size[3])
    return image_xyz
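
A brief sketch with the same assumed import path, pairing linear_rgb_to_xyz with xyz_to_linear_rgb (documented further below) on a small batch:

import torch
from odak.learn.perception.color_conversion import linear_rgb_to_xyz, xyz_to_linear_rgb

image = torch.rand(2, 3, 16, 16)             # batch of linear RGB images, [k x 3 x m x n]
image_xyz = linear_rgb_to_xyz(image)         # [2 x 3 x 16 x 16]
image_back = xyz_to_linear_rgb(image_xyz)
print(torch.abs(image_back - image).max())   # near zero; the matrices are near-inverses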

rgb_2_ycrcb(image)

Converts an image from RGB colourspace to YCrCb colourspace.

Parameters:

  • image
      Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
    

Returns:

  • ycrcb ( tensor ) –

    Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_2_ycrcb(image):
    """
    Converts an image from RGB colourspace to YCrCb colourspace.

    Parameters
    ----------
    image   : torch.tensor
              Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].

    Returns
    -------

    ycrcb   : torch.tensor
              Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
       image = image.unsqueeze(0)
    ycrcb = torch.zeros(image.size()).to(image.device)
    ycrcb[:, 0, :, :] = 0.299 * image[:, 0, :, :] + 0.587 * \
        image[:, 1, :, :] + 0.114 * image[:, 2, :, :]
    ycrcb[:, 1, :, :] = 0.5 + 0.713 * (image[:, 0, :, :] - ycrcb[:, 0, :, :])
    ycrcb[:, 2, :, :] = 0.5 + 0.564 * (image[:, 2, :, :] - ycrcb[:, 0, :, :])
    return ycrcb
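
A minimal sketch of a YCrCb round trip, assuming the same import path and using ycrcb_2_rgb (documented further below):

import torch
from odak.learn.perception.color_conversion import rgb_2_ycrcb, ycrcb_2_rgb

image = torch.rand(1, 3, 32, 32)           # RGB image in NCHW format, values in [0, 1]
ycrcb = rgb_2_ycrcb(image)
recovered = ycrcb_2_rgb(ycrcb)
print(torch.abs(recovered - image).max())  # small residual from the rounded coefficients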

rgb_to_hsv(image, eps=1e-08)

Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

  • image
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    

Returns:

  • image_hsv ( tensor ) –

    Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_hsv(image, eps: float = 1e-8):

    """
    Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

    Parameters
    ----------
    image           : torch.tensor
                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

    Returns
    -------
    image_hsv       : torch.tensor
                      Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    max_rgb, argmax_rgb = image.max(-3)
    min_rgb, argmin_rgb = image.min(-3)
    deltac = max_rgb - min_rgb
    v = max_rgb
    s = deltac / (max_rgb + eps)
    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)
    h1 = bc - gc
    h2 = (rc - bc) + 2.0 * deltac
    h3 = (gc - rc) + 4.0 * deltac
    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
    h = (h / 6.0) % 1.0
    h = 2.0 * math.pi * h 
    image_hsv = torch.stack((h, s, v), dim=-3)
    return image_hsv
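
A short sketch with the same assumed import path, showing that the returned hue channel is expressed in radians:

import torch
from odak.learn.perception.color_conversion import rgb_to_hsv

image = torch.rand(3, 32, 32)              # RGB image, values in [0, 1]
image_hsv = rgb_to_hsv(image)              # returns [1 x 3 x 32 x 32]
hue, saturation, value = image_hsv[0]
print(float(hue.min()), float(hue.max()))  # hue lies in [0, 2*pi)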

rgb_to_linear_rgb(image, threshold=0.0031308)

Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

  • image
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    
  • threshold
              Threshold used in calculations.
    

Returns:

  • image_linear ( tensor ) –

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_linear_rgb(image, threshold = 0.0031308):
    """
    Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

    Parameters
    ----------
    image           : torch.tensor
                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    threshold       : float
                      Threshold used in calculations.

    Returns
    -------
    image_linear    : torch.tensor
                      Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    # Note: the standard sRGB decoding cut-off (0.04045) is applied here rather than the `threshold` argument.
    image_linear = torch.where(image > 0.04045, torch.pow(((image + 0.055) / 1.055), 2.4), image / 12.92)
    return image_linear

srgb_to_lab(image)

Definition to convert SRGB space to LAB color space.

Parameters:

  • image
              Input image in SRGB color space [3 x m x n]
    

Returns:

  • image_lab ( tensor ) –

    Output image in LAB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def srgb_to_lab(image):    
    """
    Definition to convert SRGB space to LAB color space. 

    Parameters
    ----------
    image           : torch.tensor
                      Input image in SRGB color space [3 x m x n]
    Returns
    -------
    image_lab       : torch.tensor
                      Output image in LAB color space [3 x m x n].
    """
    if image.shape[-1] == 3:
        input_color = image.permute(2, 0, 1)  # C(H*W)
    else:
        input_color = image
    # rgb ---> linear rgb
    limit = 0.04045        
    # linear rgb ---> xyz
    linrgb_color = torch.where(input_color > limit, torch.pow((input_color + 0.055) / 1.055, 2.4), input_color / 12.92)

    a11 = 10135552 / 24577794
    a12 = 8788810  / 24577794
    a13 = 4435075  / 24577794
    a21 = 2613072  / 12288897
    a22 = 8788810  / 12288897
    a23 = 887015   / 12288897
    a31 = 1425312  / 73733382
    a32 = 8788810  / 73733382
    a33 = 70074185 / 73733382

    A = torch.tensor([[a11, a12, a13],
                    [a21, a22, a23],
                    [a31, a32, a33]], dtype=torch.float32)

    linrgb_color = linrgb_color.permute(2, 0, 1) # C(H*W)
    xyz_color = torch.matmul(A, linrgb_color)
    xyz_color = xyz_color.permute(1, 2, 0)
    # xyz ---> lab
    inv_reference_illuminant = torch.tensor([[[1.052156925]], [[1.000000000]], [[0.918357670]]], dtype=torch.float32)
    input_color = xyz_color * inv_reference_illuminant
    delta = 6 / 29
    delta_square = delta * delta
    delta_cube = delta * delta_square
    factor = 1 / (3 * delta_square)

    input_color = torch.where(input_color > delta_cube, torch.pow(input_color, 1 / 3), (factor * input_color + 4 / 29))

    l = 116 * input_color[1:2, :, :] - 16
    a = 500 * (input_color[0:1,:, :] - input_color[1:2, :, :])
    b = 200 * (input_color[1:2, :, :] - input_color[2:3, :, :])

    image_lab = torch.cat((l, a, b), 0)
    return image_lab    

xyz_to_linear_rgb(image)

Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

  • image
               Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    

Returns:

  • image_linear_rgb ( tensor ) –

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def xyz_to_linear_rgb(image):
    """
    Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

    Parameters
    ----------
    image            : torch.tensor
                       Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

    Returns
    -------
    image_linear_rgb : torch.tensor
                       Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    a11 = 3.240479
    a12 = -1.537150
    a13 = -0.498535
    a21 = -0.969256 
    a22 = 1.875992 
    a23 = 0.041556
    a31 = 0.055648
    a32 = -0.204043
    a33 = 1.057311
    M = torch.tensor([[a11, a12, a13], 
                      [a21, a22, a23],
                      [a31, a32, a33]])
    size = image.size()
    image = image.reshape(size[0], size[1], size[2]*size[3])
    image_linear_rgb = torch.matmul(M, image)
    image_linear_rgb = image_linear_rgb.reshape(size[0], size[1], size[2], size[3])
    return image_linear_rgb

ycrcb_2_rgb(image)

Converts an image from YCrCb colourspace to RGB colourspace.

Parameters:

  • image
      Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
    

Returns:

  • rgb ( tensor ) –

    Image converted to RGB colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def ycrcb_2_rgb(image):
    """
    Converts an image from YCrCb colourspace to RGB colourspace.

    Parameters
    ----------
    image   : torch.tensor
              Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].

    Returns
    -------
    rgb     : torch.tensor
              Image converted to RGB colourspace [k x 3 x m x n] or [1 x 3 x m x n].
    """
    if len(image.shape) == 3:
       image = image.unsqueeze(0)
    rgb = torch.zeros(image.size(), device=image.device)
    rgb[:, 0, :, :] = image[:, 0, :, :] + 1.403 * (image[:, 1, :, :] - 0.5)
    rgb[:, 1, :, :] = image[:, 0, :, :] - 0.714 * \
        (image[:, 1, :, :] - 0.5) - 0.344 * (image[:, 2, :, :] - 0.5)
    rgb[:, 2, :, :] = image[:, 0, :, :] + 1.773 * (image[:, 2, :, :] - 0.5)
    return rgb

make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)

Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction.

Parameters:

  • image_pixel_size
                        The size of the image in pixels, as a tuple of form (height, width)
    
  • real_image_width
                        The real width of the image as displayed. Units not important, as long as they
                        are the same as those used for real_viewing_distance
    
  • real_viewing_distance
                        The real distance from the user's viewpoint to the screen.
    

Returns:

  • map ( tensor ) –

    The computed 3D location map, of size 3xHxW.

Source code in odak/learn/perception/foveation.py
def make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):
    """ 
    Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to
    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is 
    perpendicular to the viewing direction.

    Parameters
    ----------

    image_pixel_size        : tuple of ints 
                                The size of the image in pixels, as a tuple of form (height, width)
    real_image_width        : float
                                The real width of the image as displayed. Units not important, as long as they
                                are the same as those used for real_viewing_distance
    real_viewing_distance   : float 
                                The real distance from the user's viewpoint to the screen.

    Returns
    -------

    map                     : torch.tensor
                                The computed 3D location map, of size 3xHxW.
    """
    real_image_height = (real_image_width /
                         image_pixel_size[-1]) * image_pixel_size[-2]
    x_coords = torch.linspace(-0.5, 0.5, image_pixel_size[-1])*real_image_width
    x_coords = x_coords[None, None, :].repeat(1, image_pixel_size[-2], 1)
    y_coords = torch.linspace(-0.5, 0.5,
                              image_pixel_size[-2])*real_image_height
    y_coords = y_coords[None, :, None].repeat(1, 1, image_pixel_size[-1])
    z_coords = torch.ones(
        (1, image_pixel_size[-2], image_pixel_size[-1])) * real_viewing_distance

    return torch.cat([x_coords, y_coords, z_coords], dim=0)
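
A usage sketch, assuming make_3d_location_map is importable from odak.learn.perception.foveation as the source path above suggests:

from odak.learn.perception.foveation import make_3d_location_map

# 3D location of each pixel for a 0.3 m wide screen viewed from 0.6 m away.
location_map = make_3d_location_map(
    image_pixel_size = (1080, 1920),
    real_image_width = 0.3,
    real_viewing_distance = 0.6
)
print(location_map.shape)   # torch.Size([3, 1080, 1920]); x, y, z per pixel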

make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)

Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output in radians.

Parameters:

  • gaze_location
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                        image coordinates (ranging from 0 to 1)
    
  • image_pixel_size
                        The size of the image in pixels, as a tuple of form (height, width)
    
  • real_image_width
                        The real width of the image as displayed. Units not important, as long as they
                        are the same as those used for real_viewing_distance
    
  • real_viewing_distance
                        The real distance from the user's viewpoint to the screen.
    

Returns:

  • eccentricity_map ( tensor ) –

    The computed eccentricity map, of size HxW.

  • distance_map ( tensor ) –

    The computed distance map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):
    """ 
    Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to
    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is 
    perpendicular to the viewing direction. Output in radians.

    Parameters
    ----------

    gaze_location           : tuple of floats
                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                                image coordinates (ranging from 0 to 1)
    image_pixel_size        : tuple of ints
                                The size of the image in pixels, as a tuple of form (height, width)
    real_image_width        : float
                                The real width of the image as displayed. Units not important, as long as they
                                are the same as those used for real_viewing_distance
    real_viewing_distance   : float
                                The real distance from the user's viewpoint to the screen.

    Returns
    -------

    eccentricity_map        : torch.tensor
                                The computed eccentricity map, of size HxW.
    distance_map            : torch.tensor
                                The computed distance map, of size HxW.
    """
    real_image_height = (real_image_width /
                         image_pixel_size[-1]) * image_pixel_size[-2]
    location_map = make_3d_location_map(
        image_pixel_size, real_image_width, real_viewing_distance)
    distance_map = torch.sqrt(torch.sum(location_map*location_map, dim=0))
    direction_map = location_map / distance_map

    gaze_location_3d = torch.tensor([
        (gaze_location[0]*2 - 1)*real_image_width*0.5,
        (gaze_location[1]*2 - 1)*real_image_height*0.5,
        real_viewing_distance])
    gaze_dir = gaze_location_3d / \
        torch.sqrt(torch.sum(gaze_location_3d * gaze_location_3d))
    gaze_dir = gaze_dir[:, None, None]

    dot_prod_map = torch.sum(gaze_dir * direction_map, dim=0)
    dot_prod_map = torch.clamp(dot_prod_map, min=-1.0, max=1.0)
    eccentricity_map = torch.acos(dot_prod_map)

    return eccentricity_map, distance_map
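
A usage sketch with the same assumed import path, fixating slightly to the right of the image centre:

from odak.learn.perception.foveation import make_eccentricity_distance_maps

eccentricity_map, distance_map = make_eccentricity_distance_maps(
    gaze_location = (0.6, 0.5),            # normalized image coordinates
    image_pixel_size = (512, 512),
    real_image_width = 0.3,
    real_viewing_distance = 0.6
)
# Eccentricity is in radians and approaches zero at the fixation point.
print(eccentricity_map.shape, float(eccentricity_map.min()))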

make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')

This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from to achieve the correct pooling region areas.

Parameters:

  • gaze_angles
                    Gaze direction expressed as angles, in radians.
    
  • image_pixel_size
                    Dimensions of the image in pixels, as a tuple of (height, width)
    
  • alpha
                    Parameter controlling extent of foveation
    
  • mode
                    Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
    

Returns:

  • pooling_size_map ( tensor ) –

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode="quadratic"):
    """ 
    This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from
    to achieve the correct pooling region areas.

    Parameters
    ----------

    gaze_angles         : tuple of 2 floats
                            Gaze direction expressed as angles, in radians.
    image_pixel_size    : tuple of 2 ints
                            Dimensions of the image in pixels, as a tuple of (height, width)
    alpha               : float
                            Parameter controlling extent of foveation
    mode                : str
                            Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"

    Returns
    -------

    pooling_size_map        : torch.tensor
                                The computed pooling size map, of size HxW.
    """
    pooling_pixel = make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha, mode)
    pooling_lod = torch.log2(1e-6+pooling_pixel)
    pooling_lod[pooling_lod < 0] = 0
    return pooling_lod

make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')

This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images. Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image corresponds to pitch, ranging from -pi/2 to pi/2.

In this setup real_image_width and real_viewing_distance have no effect.

Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).

Parameters:

  • gaze_angles
                    Gaze direction expressed as angles, in radians.
    
  • image_pixel_size
                    Dimensions of the image in pixels, as a tuple of (height, width)
    
  • alpha
                    Parameter controlling extent of foveation
    
  • mode
                    Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
    
Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode="quadratic"):
    """
    This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images.
    Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, 
    the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image
    corresponds to pitch, ranging from -pi/2 to pi/2.

    In this setup real_image_width and real_viewing_distance have no effect.

    Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).

    Parameters
    ----------

    gaze_angles         : tuple of 2 floats
                            Gaze direction expressed as angles, in radians.
    image_pixel_size    : tuple of 2 ints
                            Dimensions of the image in pixels, as a tuple of (height, width)
    alpha               : float
                            Parameter controlling extent of foveation
    mode                : str
                            Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
    """
    view_direction = torch.tensor([math.sin(gaze_angles[0])*math.cos(gaze_angles[1]), math.sin(gaze_angles[1]), math.cos(gaze_angles[0])*math.cos(gaze_angles[1])])

    yaw_angle_map = torch.linspace(-torch.pi, torch.pi, image_pixel_size[1])
    yaw_angle_map = yaw_angle_map[None,:].repeat(image_pixel_size[0], 1)[None,...]
    pitch_angle_map = torch.linspace(-torch.pi*0.5, torch.pi*0.5, image_pixel_size[0])
    pitch_angle_map = pitch_angle_map[:,None].repeat(1, image_pixel_size[1])[None,...]

    dir_map = torch.cat([torch.sin(yaw_angle_map)*torch.cos(pitch_angle_map), torch.sin(pitch_angle_map), torch.cos(yaw_angle_map)*torch.cos(pitch_angle_map)])

    # Work out the pooling region diameter in radians
    view_dot_dir = torch.sum(view_direction[:,None,None] * dir_map, dim=0)
    eccentricity = torch.acos(view_dot_dir)
    pooling_rad = alpha * eccentricity
    if mode == "quadratic":
        pooling_rad *= eccentricity

    # The actual pooling region will be an ellipse in the equirectangular image - the length of the major & minor axes
    # depend on the x & y resolution of the image. We find these two axis lengths (in pixels) and then the area of the ellipse
    pixels_per_rad_x = image_pixel_size[1] / (2*torch.pi)
    pixels_per_rad_y = image_pixel_size[0] / (torch.pi)
    pooling_axis_x = pooling_rad * pixels_per_rad_x
    pooling_axis_y = pooling_rad * pixels_per_rad_y
    area = torch.pi * pooling_axis_x * pooling_axis_y * 0.25

    # Now finally find the length of the side of a square of the same area.
    size = torch.sqrt(torch.abs(area))
    return size
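
A usage sketch with the same assumed import path, computing pooling sizes over an equirectangular 360 image for a forward gaze:

from odak.learn.perception.foveation import make_equi_pooling_size_map_pixels

pooling_map = make_equi_pooling_size_map_pixels(
    gaze_angles = (0.0, 0.0),              # (yaw, pitch) in radians
    image_pixel_size = (512, 1024),
    alpha = 0.3,
    mode = "quadratic"
)
print(pooling_map.shape)                   # torch.Size([512, 1024])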

make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')

This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from to achieve the correct pooling region areas.

Parameters:

  • gaze_location
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                        image coordinates (ranging from 0 to 1)
    
  • image_pixel_size
                        The size of the image in pixels, as a tuple of form (height, width)
    
  • real_image_width
                        The real width of the image as displayed. Units not important, as long as they
                        are the same as those used for real_viewing_distance
    
  • real_viewing_distance
                        The real distance from the user's viewpoint to the screen.
    

Returns:

  • pooling_size_map ( tensor ) –

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode="quadratic"):
    """ 
    This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from
    to achieve the correct pooling region areas.

    Parameters
    ----------

    gaze_location           : tuple of floats
                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                                image coordinates (ranging from 0 to 1)
    image_pixel_size        : tuple of ints
                                The size of the image in pixels, as a tuple of form (height, width)
    real_image_width        : float
                                The real width of the image as displayed. Units not important, as long as they
                                are the same as those used for real_viewing_distance
    real_viewing_distance   : float
                                The real distance from the user's viewpoint to the screen.

    Returns
    -------

    pooling_size_map        : torch.tensor
                                The computed pooling size map, of size HxW.
    """
    pooling_pixel = make_pooling_size_map_pixels(
        gaze_location, image_pixel_size, alpha, real_image_width, real_viewing_distance, mode)
    pooling_lod = torch.log2(1e-6+pooling_pixel)
    pooling_lod[pooling_lod < 0] = 0
    return pooling_lod

make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')

Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity (also in radians).

Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output is the width of the pooling region in pixels.

Parameters:

  • gaze_location
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                        image coordinates (ranging from 0 to 1)
    
  • image_pixel_size
                        The size of the image in pixels, as a tuple of form (height, width)
    
  • real_image_width
                        The real width of the image as displayed. Units not important, as long as they
                        are the same as those used for real_viewing_distance
    
  • real_viewing_distance
                        The real distance from the user's viewpoint to the screen.
    

Returns:

  • pooling_size_map ( tensor ) –

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode="quadratic"):
    """ 
    Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to
    a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity
    (also in radians). 

    Assumes the viewpoint is located at the centre of the image, and the screen is 
    perpendicular to the viewing direction. Output is the width of the pooling region in pixels.

    Parameters
    ----------

    gaze_location           : tuple of floats
                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                                image coordinates (ranging from 0 to 1)
    image_pixel_size        : tuple of ints
                                The size of the image in pixels, as a tuple of form (height, width)
    real_image_width        : float
                                The real width of the image as displayed. Units not important, as long as they
                                are the same as those used for real_viewing_distance
    real_viewing_distance   : float
                                The real distance from the user's viewpoint to the screen.

    Returns
    -------

    pooling_size_map        : torch.tensor
                                The computed pooling size map, of size HxW.
    """
    eccentricity, distance_to_pixel = make_eccentricity_distance_maps(
        gaze_location, image_pixel_size, real_image_width, real_viewing_distance)
    eccentricity_centre, _ = make_eccentricity_distance_maps(
        [0.5, 0.5], image_pixel_size, real_image_width, real_viewing_distance)
    pooling_rad = alpha * eccentricity
    if mode == "quadratic":
        pooling_rad *= eccentricity
    angle_min = eccentricity_centre - pooling_rad*0.5
    angle_max = eccentricity_centre + pooling_rad*0.5
    major_axis = (torch.tan(angle_max) - torch.tan(angle_min)) * \
        real_viewing_distance
    minor_axis = 2 * distance_to_pixel * torch.tan(pooling_rad*0.5)
    area = math.pi * major_axis * minor_axis * 0.25
    # Should be +ve anyway, but check to ensure we don't take sqrt of negative number
    area = torch.abs(area)
    pooling_real = torch.sqrt(area)
    pooling_pixel = (pooling_real / real_image_width) * image_pixel_size[1]
    return pooling_pixel
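
A usage sketch with the same assumed import path, computing both the pixel pooling sizes and the matching LOD levels for a central gaze:

from odak.learn.perception.foveation import make_pooling_size_map_pixels, make_pooling_size_map_lod

pooling_pixels = make_pooling_size_map_pixels(
    gaze_location = (0.5, 0.5),
    image_pixel_size = (512, 512),
    alpha = 0.3,
    real_image_width = 0.3,
    real_viewing_distance = 0.6,
    mode = "quadratic"
)
pooling_lod = make_pooling_size_map_lod(gaze_location = (0.5, 0.5), image_pixel_size = (512, 512))
print(pooling_pixels.shape, float(pooling_lod.max()))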

make_radial_map(size, gaze)

Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.

Parameters:

  • size
        Dimensions of the image
    
  • gaze
        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
        image coordinates (ranging from 0 to 1)
    
Source code in odak/learn/perception/foveation.py
def make_radial_map(size, gaze):
    """ 
    Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.

    Parameters
    ----------

    size    : tuple of ints
                Dimensions of the image
    gaze    : tuple of floats
                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
                image coordinates (ranging from 0 to 1)
    """
    pix_gaze = [gaze[0]*size[0], gaze[1]*size[1]]
    rows = torch.linspace(0, size[0], size[0])
    rows = rows[:, None].repeat(1, size[1])
    cols = torch.linspace(0, size[1], size[1])
    cols = cols[None, :].repeat(size[0], 1)
    dist_sq = torch.pow(rows - pix_gaze[0], 2) + \
        torch.pow(cols - pix_gaze[1], 2)
    radii = torch.sqrt(dist_sq)
    return radii/torch.max(radii)
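
A usage sketch with the same assumed import path:

from odak.learn.perception.foveation import make_radial_map

radial_map = make_radial_map(size = (256, 256), gaze = (0.5, 0.5))
print(radial_map.shape, float(radial_map.max()))   # distances normalized so the maximum is 1.0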

MetamericLoss

The MetamericLoss class provides a perceptual loss function.

Rather than exactly match the source image to the target, it tries to ensure the source is a metamer to the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/metameric_loss.py
class MetamericLoss():
    """
    The `MetamericLoss` class provides a perceptual loss function.

    Rather than exactly match the source image to the target, it tries to ensure the source is a *metamer* to the target image.

    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
    """


    def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
                 real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
                 n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
                 use_fullres_l0=False, equi=False):
        """
        Parameters
        ----------

        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
        n_pyramid_levels        : int 
                                    Number of levels of the steerable pyramid. Note that the image is padded
                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                    too high will slow down the calculation a lot.
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        n_orientations          : int 
                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                    Increasing this will increase runtime.
        use_l2_foveal_loss      : bool 
                                    If true, all pixels whose pooling size is 1 pixel at the largest scale will use a direct
                                    L2 loss against the target rather than pooling over pyramid levels.
                                    In practice this gives better results when the loss is used for holography.
        fovea_weight            : float 
                                    A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
        use_radial_weight       : bool 
                                    If True, will apply a radial weighting when calculating the difference between
                                    the source and target stats maps. This weights stats closer to the fovea more than those
                                    further away.
        use_fullres_l0          : bool 
                                    If true, stats for the lowpass residual are replaced with blurred versions
                                    of the full-resolution source and target images.
        equi                    : bool
                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi/2]
        """
        self.target = None
        self.device = device
        self.pyramid_maker = None
        self.alpha = alpha
        self.real_image_width = real_image_width
        self.real_viewing_distance = real_viewing_distance
        self.blurs = None
        self.n_pyramid_levels = n_pyramid_levels
        self.n_orientations = n_orientations
        self.mode = mode
        self.use_l2_foveal_loss = use_l2_foveal_loss
        self.fovea_weight = fovea_weight
        self.use_radial_weight = use_radial_weight
        self.use_fullres_l0 = use_fullres_l0
        self.equi = equi
        if self.use_fullres_l0 and self.use_l2_foveal_loss:
            raise Exception(
                "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")

    def calc_statsmaps(self, image, gaze=None, alpha=0.01, real_image_width=0.3,
                       real_viewing_distance=0.6, mode="quadratic", equi=False):

        if self.pyramid_maker is None or \
                self.pyramid_maker.device != self.device or \
                len(self.pyramid_maker.band_filters) != self.n_orientations or\
                self.pyramid_maker.filt_h0.size(0) != image.size(1):
            self.pyramid_maker = SpatialSteerablePyramid(
                use_bilinear_downup=False, n_channels=image.size(1),
                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)

        if self.blurs is None or len(self.blurs) != self.n_pyramid_levels:
            self.blurs = [RadiallyVaryingBlur()
                          for i in range(self.n_pyramid_levels)]

        def find_stats(image_pyr_level, blur):
            image_means = blur.blur(
                image_pyr_level, alpha, real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)
            image_meansq = blur.blur(image_pyr_level*image_pyr_level, alpha,
                                     real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)

            image_vars = image_meansq - (image_means*image_means)
            image_vars[image_vars < 1e-7] = 1e-7
            image_std = torch.sqrt(image_vars)
            if torch.any(torch.isnan(image_means)):
                print(image_means)
                raise Exception("NaN in image means!")
            if torch.any(torch.isnan(image_std)):
                print(image_std)
                raise Exception("NaN in image stdevs!")
            if self.use_fullres_l0:
                mask = blur.lod_map > 1e-6
                mask = mask[None, None, ...]
                if image_means.size(1) > 1:
                    mask = mask.repeat(1, image_means.size(1), 1, 1)
                matte = torch.zeros_like(image_means)
                matte[mask] = 1.0
                return image_means * matte, image_std * matte
            return image_means, image_std
        output_stats = []
        image_pyramid = self.pyramid_maker.construct_pyramid(
            image, self.n_pyramid_levels)
        means, variances = find_stats(image_pyramid[0]['h'], self.blurs[0])
        if self.use_l2_foveal_loss:
            self.fovea_mask = torch.zeros(image.size(), device=image.device)
            for i in range(self.fovea_mask.size(1)):
                self.fovea_mask[0, i, ...] = 1.0 - \
                    (self.blurs[0].lod_map / torch.max(self.blurs[0].lod_map))
                self.fovea_mask[0, i, self.blurs[0].lod_map < 1e-6] = 1.0
            self.fovea_mask = torch.pow(self.fovea_mask, 10.0)
            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, scale_factor=0.125, mode="area")
            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, size=(image.size(-2), image.size(-1)), mode="bilinear")
            periphery_mask = 1.0 - self.fovea_mask
            self.periphery_mask = periphery_mask.clone()
            output_stats.append(means * periphery_mask)
            output_stats.append(variances * periphery_mask)
        else:
            output_stats.append(means)
            output_stats.append(variances)

        for l in range(0, len(image_pyramid)-1):
            for o in range(len(image_pyramid[l]['b'])):
                means, variances = find_stats(
                    image_pyramid[l]['b'][o], self.blurs[l])
                if self.use_l2_foveal_loss:
                    output_stats.append(means * periphery_mask)
                    output_stats.append(variances * periphery_mask)
                else:
                    output_stats.append(means)
                    output_stats.append(variances)
            if self.use_l2_foveal_loss:
                periphery_mask = torch.nn.functional.interpolate(
                    periphery_mask, scale_factor=0.5, mode="area", recompute_scale_factor=False)

        if self.use_l2_foveal_loss:
            output_stats.append(image_pyramid[-1]["l"] * periphery_mask)
        elif self.use_fullres_l0:
            output_stats.append(self.blurs[0].blur(
                image, alpha, real_image_width, real_viewing_distance, gaze, mode))
        else:
            output_stats.append(image_pyramid[-1]["l"])
        return output_stats

    def metameric_loss_stats(self, statsmap_a, statsmap_b, gaze):
        loss = 0.0
        for a, b in zip(statsmap_a, statsmap_b):
            if self.use_radial_weight:
                radii = make_radial_map(
                    [a.size(-2), a.size(-1)], gaze).to(a.device)
                weights = 1.1 - (radii * radii * radii * radii)
                weights = weights[None, None, ...].repeat(1, a.size(1), 1, 1)
                loss += torch.nn.MSELoss()(weights*a, weights*b)
            else:
                loss += torch.nn.MSELoss()(a, b)
        loss /= len(statsmap_a)
        return loss

    def visualise_loss_map(self, image_stats):
        loss_map = torch.zeros(image_stats[0].size()[-2:])
        for i in range(len(image_stats)):
            stats = image_stats[i]
            target_stats = self.target_stats[i]
            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))
            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(
            ), mode="bilinear", align_corners=False, recompute_scale_factor=False)
            loss_map += stat_mse_map[0, 0, ...]
        self.loss_map = loss_map

    def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
        """ 
        Calculates the Metameric Loss.

        Parameters
        ----------
        image               : torch.tensor
                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target              : torch.tensor
                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        image_colorspace    : str
                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
                                accepted values: RGB, YCrCb.
        gaze                : list
                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
        visualise_loss      : bool
                                Shows a heatmap indicating which parts of the image contributed most to the loss. 

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("MetamericLoss", image, target)
        # Pad image and target if necessary
        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
        # If input is RGB, convert to YCrCb.
        if image.size(1) == 3 and image_colorspace == "RGB":
            image = rgb_2_ycrcb(image)
            target = rgb_2_ycrcb(target)
        if self.target is None:
            self.target = torch.zeros(target.shape).to(target.device)
        if type(target) == type(self.target):
            if not torch.all(torch.eq(target, self.target)):
                self.target = target.detach().clone()
                self.target_stats = self.calc_statsmaps(
                    self.target,
                    gaze=gaze,
                    alpha=self.alpha,
                    real_image_width=self.real_image_width,
                    real_viewing_distance=self.real_viewing_distance,
                    mode=self.mode
                )
                self.target = target.detach().clone()
            image_stats = self.calc_statsmaps(
                image,
                gaze=gaze,
                alpha=self.alpha,
                real_image_width=self.real_image_width,
                real_viewing_distance=self.real_viewing_distance,
                mode=self.mode
            )
            if visualise_loss:
                self.visualise_loss_map(image_stats)
            if self.use_l2_foveal_loss:
                peripheral_loss = self.metameric_loss_stats(
                    image_stats, self.target_stats, gaze)
                foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)
                # New weighting - evenly weight fovea and periphery.
                loss = peripheral_loss + self.fovea_weight * foveal_loss
            else:
                loss = self.metameric_loss_stats(
                    image_stats, self.target_stats, gaze)
            return loss
        else:
            raise Exception("Target of incorrect type")

    def to(self, device):
        self.device = device
        return self

__call__(image, target, gaze=[0.5, 0.5], image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • target
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    
  • image_colorspace
                    The current colorspace of your image and target. Ignored if input does not have 3 channels.
                    accepted values: RGB, YCrCb.
    
  • gaze
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    
  • visualise_loss
                    Shows a heatmap indicating which parts of the image contributed most to the loss.
    

Returns:

  • loss ( tensor ) –

    The computed loss.

Source code in odak/learn/perception/metameric_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
    """ 
    Calculates the Metameric Loss.

    Parameters
    ----------
    image               : torch.tensor
                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target              : torch.tensor
                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    image_colorspace    : str
                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
                            accepted values: RGB, YCrCb.
    gaze                : list
                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    visualise_loss      : bool
                            Shows a heatmap indicating which parts of the image contributed most to the loss. 

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("MetamericLoss", image, target)
    # Pad image and target if necessary
    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
    # If input is RGB, convert to YCrCb.
    if image.size(1) == 3 and image_colorspace == "RGB":
        image = rgb_2_ycrcb(image)
        target = rgb_2_ycrcb(target)
    if self.target is None:
        self.target = torch.zeros(target.shape).to(target.device)
    if type(target) == type(self.target):
        if not torch.all(torch.eq(target, self.target)):
            self.target = target.detach().clone()
            self.target_stats = self.calc_statsmaps(
                self.target,
                gaze=gaze,
                alpha=self.alpha,
                real_image_width=self.real_image_width,
                real_viewing_distance=self.real_viewing_distance,
                mode=self.mode
            )
            self.target = target.detach().clone()
        image_stats = self.calc_statsmaps(
            image,
            gaze=gaze,
            alpha=self.alpha,
            real_image_width=self.real_image_width,
            real_viewing_distance=self.real_viewing_distance,
            mode=self.mode
        )
        if visualise_loss:
            self.visualise_loss_map(image_stats)
        if self.use_l2_foveal_loss:
            peripheral_loss = self.metameric_loss_stats(
                image_stats, self.target_stats, gaze)
            foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)
            # New weighting - evenly weight fovea and periphery.
            loss = peripheral_loss + self.fovea_weight * foveal_loss
        else:
            loss = self.metameric_loss_stats(
                image_stats, self.target_stats, gaze)
        return loss
    else:
        raise Exception("Target of incorrect type")

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, n_pyramid_levels=5, mode='quadratic', n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False, use_fullres_l0=False, equi=False)

Parameters:

  • alpha
                        parameter controlling foveation - larger values mean bigger pooling regions.
    
  • real_image_width
                        The real width of the image as displayed to the user.
                        Units don't matter as long as they are the same as for real_viewing_distance.
    
  • real_viewing_distance
                        The real distance of the observer's eyes to the image plane.
                        Units don't matter as long as they are the same as for real_image_width.
    
  • n_pyramid_levels
                        Number of levels of the steerable pyramid. Note that the image is padded
                        so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                        too high will slow down the calculation a lot.
    
  • mode
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                        as you move away from the fovea. We got best results with "quadratic".
    
  • n_orientations
                        Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                        Increasing this will increase runtime.
    
  • use_l2_foveal_loss
                        If true, all pixels whose pooling size is 1 pixel at the largest scale will use a direct
                        L2 loss against the target rather than pooling over pyramid levels.
                        In practice this gives better results when the loss is used for holography.
    
  • fovea_weight
                        A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
    
  • use_radial_weight
                        If True, will apply a radial weighting when calculating the difference between
                        the source and target stats maps. This weights stats closer to the fovea more than those
                        further away.
    
  • use_fullres_l0
                        If true, stats for the lowpass residual are replaced with blurred versions
                        of the full-resolution source and target images.
    
  • equi
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                        format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                        The gaze argument is instead interpreted as gaze angles, and should be in the range
                        [-pi,pi]x[-pi/2,pi/2]
    
Source code in odak/learn/perception/metameric_loss.py
def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
             real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
             n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
             use_fullres_l0=False, equi=False):
    """
    Parameters
    ----------

    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
    n_pyramid_levels        : int 
                                Number of levels of the steerable pyramid. Note that the image is padded
                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                too high will slow down the calculation a lot.
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    n_orientations          : int 
                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                Increasing this will increase runtime.
    use_l2_foveal_loss      : bool 
                                 If true, all pixels whose pooling size is 1 pixel at the largest scale will use a direct
                                 L2 loss against the target rather than pooling over pyramid levels.
                                In practice this gives better results when the loss is used for holography.
    fovea_weight            : float 
                                A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
    use_radial_weight       : bool 
                                If True, will apply a radial weighting when calculating the difference between
                                the source and target stats maps. This weights stats closer to the fovea more than those
                                further away.
    use_fullres_l0          : bool 
                                If true, stats for the lowpass residual are replaced with blurred versions
                                of the full-resolution source and target images.
    equi                    : bool
                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                 format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                 The gaze argument is instead interpreted as gaze angles, and should be in the range
                                 [-pi,pi]x[-pi/2,pi/2]
    """
    self.target = None
    self.device = device
    self.pyramid_maker = None
    self.alpha = alpha
    self.real_image_width = real_image_width
    self.real_viewing_distance = real_viewing_distance
    self.blurs = None
    self.n_pyramid_levels = n_pyramid_levels
    self.n_orientations = n_orientations
    self.mode = mode
    self.use_l2_foveal_loss = use_l2_foveal_loss
    self.fovea_weight = fovea_weight
    self.use_radial_weight = use_radial_weight
    self.use_fullres_l0 = use_fullres_l0
    self.equi = equi
    if self.use_fullres_l0 and self.use_l2_foveal_loss:
        raise Exception(
            "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")

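A minimal optimisation sketch, assuming MetamericLoss is importable directly from odak.learn.perception as this page suggests; the random tensors stand in for real source and target images:

import torch
from odak.learn.perception import MetamericLoss

loss_function = MetamericLoss(alpha = 0.2, real_image_width = 0.2, real_viewing_distance = 0.7)
target = torch.rand(1, 3, 256, 256)                       # RGB target in NCHW format, [0, 1]
image = torch.rand(1, 3, 256, 256, requires_grad = True)  # image being optimised

optimizer = torch.optim.Adam([image], lr = 1e-2)
for _ in range(10):
    optimizer.zero_grad()
    loss = loss_function(image, target, gaze = [0.5, 0.5])
    loss.backward()
    optimizer.step()
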
MetamericLossUniform

Measures metameric loss between a given image and a metamer of the given target image. This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.

Source code in odak/learn/perception/metameric_loss_uniform.py
class MetamericLossUniform():
    """
    Measures metameric loss between a given image and a metamer of the given target image.
    This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.
    """

    def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
        """

        Parameters
        ----------
        pooling_size            : int
                                  Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
        n_pyramid_levels        : int 
                                  Number of levels of the steerable pyramid. Note that the image is padded
                                  so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                  too high will slow down the calculation a lot.
        n_orientations          : int 
                                  Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                  Increasing this will increase runtime.

        """
        self.target = None
        self.device = device
        self.pyramid_maker = None
        self.pooling_size = pooling_size
        self.n_pyramid_levels = n_pyramid_levels
        self.n_orientations = n_orientations

    def calc_statsmaps(self, image, pooling_size):

        if self.pyramid_maker is None or \
                self.pyramid_maker.device != self.device or \
                len(self.pyramid_maker.band_filters) != self.n_orientations or\
                self.pyramid_maker.filt_h0.size(0) != image.size(1):
            self.pyramid_maker = SpatialSteerablePyramid(
                use_bilinear_downup=False, n_channels=image.size(1),
                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)


        def find_stats(image_pyr_level, pooling_size):
            image_means = uniform_blur(image_pyr_level, pooling_size)
            image_meansq = uniform_blur(image_pyr_level*image_pyr_level, pooling_size)
            image_vars = image_meansq - (image_means*image_means)
            image_vars[image_vars < 1e-7] = 1e-7
            image_std = torch.sqrt(image_vars)
            if torch.any(torch.isnan(image_means)):
                print(image_means)
                raise Exception("NaN in image means!")
            if torch.any(torch.isnan(image_std)):
                print(image_std)
                raise Exception("NaN in image stdevs!")
            return image_means, image_std

        output_stats = []
        image_pyramid = self.pyramid_maker.construct_pyramid(
            image, self.n_pyramid_levels)
        curr_pooling_size = pooling_size
        means, variances = find_stats(image_pyramid[0]['h'], curr_pooling_size)
        output_stats.append(means)
        output_stats.append(variances)

        for l in range(0, len(image_pyramid)-1):
            for o in range(len(image_pyramid[l]['b'])):
                means, variances = find_stats(
                    image_pyramid[l]['b'][o], curr_pooling_size)
                output_stats.append(means)
                output_stats.append(variances)
            curr_pooling_size /= 2

        output_stats.append(image_pyramid[-1]["l"])
        return output_stats

    def metameric_loss_stats(self, statsmap_a, statsmap_b):
        loss = 0.0
        for a, b in zip(statsmap_a, statsmap_b):
            loss += torch.nn.MSELoss()(a, b)
        loss /= len(statsmap_a)
        return loss

    def visualise_loss_map(self, image_stats):
        loss_map = torch.zeros(image_stats[0].size()[-2:])
        for i in range(len(image_stats)):
            stats = image_stats[i]
            target_stats = self.target_stats[i]
            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))
            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(
            ), mode="bilinear", align_corners=False, recompute_scale_factor=False)
            loss_map += stat_mse_map[0, 0, ...]
        self.loss_map = loss_map

    def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
        """ 
        Calculates the Metameric Loss.

        Parameters
        ----------
        image               : torch.tensor
                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target              : torch.tensor
                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        image_colorspace    : str
                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
                                accepted values: RGB, YCrCb.
        visualise_loss      : bool
                                Shows a heatmap indicating which parts of the image contributed most to the loss. 

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("MetamericLossUniform", image, target)
        # Pad image and target if necessary
        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
        # If input is RGB, convert to YCrCb.
        if image.size(1) == 3 and image_colorspace == "RGB":
            image = rgb_2_ycrcb(image)
            target = rgb_2_ycrcb(target)
        if self.target is None:
            self.target = torch.zeros(target.shape).to(target.device)
        if type(target) == type(self.target):
            if not torch.all(torch.eq(target, self.target)):
                self.target = target.detach().clone()
                self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)
                self.target = target.detach().clone()
            image_stats = self.calc_statsmaps(image, self.pooling_size)

            if visualise_loss:
                self.visualise_loss_map(image_stats)
            loss = self.metameric_loss_stats(
                image_stats, self.target_stats)
            return loss
        else:
            raise Exception("Target of incorrect type")

    def gen_metamer(self, image):
        """ 
        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
        This function can be used on its own to generate a metamer for a desired image.

        Parameters
        ----------
        image   : torch.tensor
                  Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)

        Returns
        -------
        metamer : torch.tensor
                  The generated metamer image
        """
        image = rgb_2_ycrcb(image)
        image_size = image.size()
        image = pad_image_for_pyramid(image, self.n_pyramid_levels)

        target_stats = self.calc_statsmaps(
            image, self.pooling_size)
        target_means = target_stats[::2]
        target_stdevs = target_stats[1::2]
        torch.manual_seed(0)
        noise_image = torch.rand_like(image)
        noise_pyramid = self.pyramid_maker.construct_pyramid(
            noise_image, self.n_pyramid_levels)
        input_pyramid = self.pyramid_maker.construct_pyramid(
            image, self.n_pyramid_levels)

        def match_level(input_level, target_mean, target_std):
            level = input_level.clone()
            level -= torch.mean(level)
            input_std = torch.sqrt(torch.mean(level * level))
            eps = 1e-6
            # Safeguard against divide by zero
            input_std[input_std < eps] = eps
            level /= input_std
            level *= target_std
            level += target_mean
            return level

        nbands = len(noise_pyramid[0]["b"])
        noise_pyramid[0]["h"] = match_level(
            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
        for l in range(len(noise_pyramid)-1):
            for b in range(nbands):
                noise_pyramid[l]["b"][b] = match_level(
                    noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

        metamer = self.pyramid_maker.reconstruct_from_pyramid(
            noise_pyramid)
        metamer = ycrcb_2_rgb(metamer)
        # Crop to remove any padding
        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
        return metamer

    def to(self, device):
        self.device = device
        return self

__call__(image, target, image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image (torch.tensor) – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target (torch.tensor) – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • image_colorspace (str) – The current colorspace of your image and target. Ignored if input does not have 3 channels. Accepted values: RGB, YCrCb.
  • visualise_loss (bool) – Shows a heatmap indicating which parts of the image contributed most to the loss.

Returns:

  • loss (torch.tensor) – The computed loss.
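
When visualise_loss is set to True, the call also stores a heatmap in the loss_map attribute (set by visualise_loss_map in the source above). A short, hedged sketch with illustrative shapes:

import torch
from odak.learn.perception import MetamericLossUniform

loss_fn = MetamericLossUniform(pooling_size=32)
image = torch.rand(1, 3, 256, 256)
target = torch.rand(1, 3, 256, 256)
loss = loss_fn(image, target, visualise_loss=True)
heatmap = loss_fn.loss_map  # 2D map of accumulated statistic differences across pyramid levels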

Source code in odak/learn/perception/metameric_loss_uniform.py
def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
    """ 
    Calculates the Metameric Loss.

    Parameters
    ----------
    image               : torch.tensor
                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target              : torch.tensor
                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    image_colorspace    : str
                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
                            accepted values: RGB, YCrCb.
    visualise_loss      : bool
                            Shows a heatmap indicating which parts of the image contributed most to the loss. 

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("MetamericLossUniform", image, target)
    # Pad image and target if necessary
    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
    # If input is RGB, convert to YCrCb.
    if image.size(1) == 3 and image_colorspace == "RGB":
        image = rgb_2_ycrcb(image)
        target = rgb_2_ycrcb(target)
    if self.target is None:
        self.target = torch.zeros(target.shape).to(target.device)
    if type(target) == type(self.target):
        if not torch.all(torch.eq(target, self.target)):
            self.target = target.detach().clone()
            self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)
            self.target = target.detach().clone()
        image_stats = self.calc_statsmaps(image, self.pooling_size)

        if visualise_loss:
            self.visualise_loss_map(image_stats)
        loss = self.metameric_loss_stats(
            image_stats, self.target_stats)
        return loss
    else:
        raise Exception("Target of incorrect type")

__init__(device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2)

Parameters:

  • pooling_size (int) – Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
  • n_pyramid_levels (int) – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • n_orientations (int) – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
Source code in odak/learn/perception/metameric_loss_uniform.py
def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
    """

    Parameters
    ----------
    pooling_size            : int
                              Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
    n_pyramid_levels        : int 
                              Number of levels of the steerable pyramid. Note that the image is padded
                              so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                              too high will slow down the calculation a lot.
    n_orientations          : int 
                              Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                              Increasing this will increase runtime.

    """
    self.target = None
    self.device = device
    self.pyramid_maker = None
    self.pooling_size = pooling_size
    self.n_pyramid_levels = n_pyramid_levels
    self.n_orientations = n_orientations

gen_metamer(image)

Generates a metamer for an image, following the method in this paper (https://dl.acm.org/doi/abs/10.1145/3450626.3459943). This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image (torch.tensor) – Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions).

Returns:

  • metamer (torch.tensor) – The generated metamer image.
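
A hedged sketch of standalone metamer generation with uniform pooling; the import path, image size and pooling size are illustrative assumptions.

import torch
from odak.learn.perception import MetamericLossUniform

metameric = MetamericLossUniform(pooling_size=32)
image = torch.rand(1, 3, 256, 256)          # RGB image in NCHW format
with torch.no_grad():
    metamer = metameric.gen_metamer(image)  # image matching the input's local statistics under 32x32 pooling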

Source code in odak/learn/perception/metameric_loss_uniform.py
def gen_metamer(self, image):
    """ 
    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
    This function can be used on its own to generate a metamer for a desired image.

    Parameters
    ----------
    image   : torch.tensor
              Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)

    Returns
    -------
    metamer : torch.tensor
              The generated metamer image
    """
    image = rgb_2_ycrcb(image)
    image_size = image.size()
    image = pad_image_for_pyramid(image, self.n_pyramid_levels)

    target_stats = self.calc_statsmaps(
        image, self.pooling_size)
    target_means = target_stats[::2]
    target_stdevs = target_stats[1::2]
    torch.manual_seed(0)
    noise_image = torch.rand_like(image)
    noise_pyramid = self.pyramid_maker.construct_pyramid(
        noise_image, self.n_pyramid_levels)
    input_pyramid = self.pyramid_maker.construct_pyramid(
        image, self.n_pyramid_levels)

    def match_level(input_level, target_mean, target_std):
        level = input_level.clone()
        level -= torch.mean(level)
        input_std = torch.sqrt(torch.mean(level * level))
        eps = 1e-6
        # Safeguard against divide by zero
        input_std[input_std < eps] = eps
        level /= input_std
        level *= target_std
        level += target_mean
        return level

    nbands = len(noise_pyramid[0]["b"])
    noise_pyramid[0]["h"] = match_level(
        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
    for l in range(len(noise_pyramid)-1):
        for b in range(nbands):
            noise_pyramid[l]["b"][b] = match_level(
                noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

    metamer = self.pyramid_maker.reconstruct_from_pyramid(
        noise_pyramid)
    metamer = ycrcb_2_rgb(metamer)
    # Crop to remove any padding
    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
    return metamer

MetamerMSELoss

The MetamerMSELoss class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.

Please note this is different to MetamericLoss which optimises the source image to be any metamer of the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.
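
A minimal usage sketch follows, assuming MetamerMSELoss is exported from odak.learn.perception; the display parameters, gaze location and tensor shapes are illustrative.

import torch
from odak.learn.perception import MetamerMSELoss

loss_fn = MetamerMSELoss(alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7)
image = torch.rand(1, 3, 256, 256, requires_grad=True)  # source image being optimised (NCHW, RGB)
target = torch.rand(1, 3, 256, 256)                     # ground truth target (NCHW, RGB)
loss = loss_fn(image, target, gaze=[0.5, 0.5])          # gaze in normalised [0, 1] image coordinates
loss.backward()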

Source code in odak/learn/perception/metamer_mse_loss.py
class MetamerMSELoss():
    """ 
    The `MetamerMSELoss` class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.

    Please note this is different to `MetamericLoss` which optimises the source image to be any metamer of the target image.

    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
    """


    def __init__(self, device=torch.device("cpu"),
                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
                 n_pyramid_levels=5, n_orientations=2, equi=False):
        """
        Parameters
        ----------
        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
        n_pyramid_levels        : int 
                                    Number of levels of the steerable pyramid. Note that the image is padded
                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                    too high will slow down the calculation a lot.
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        n_orientations          : int 
                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                    Increasing this will increase runtime.
        equi                    : bool
                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi/2]
        """
        self.target = None
        self.target_metamer = None
        self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,
                                            real_viewing_distance=real_viewing_distance,
                                            n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)
        self.loss_func = torch.nn.MSELoss()
        self.noise = None

    def gen_metamer(self, image, gaze):
        """ 
        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
        This function can be used on its own to generate a metamer for a desired image.

        Parameters
        ----------
        image   : torch.tensor
                Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
        gaze    : list
                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

        Returns
        -------

        metamer : torch.tensor
                The generated metamer image
        """
        image = rgb_2_ycrcb(image)
        image_size = image.size()
        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)

        target_stats = self.metameric_loss.calc_statsmaps(
            image, gaze=gaze, alpha=self.metameric_loss.alpha)
        target_means = target_stats[::2]
        target_stdevs = target_stats[1::2]
        if self.noise is None or self.noise.size() != image.size():
            torch.manual_seed(0)
            noise_image = torch.rand_like(image)
        noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
            noise_image, self.metameric_loss.n_pyramid_levels)
        input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
            image, self.metameric_loss.n_pyramid_levels)

        def match_level(input_level, target_mean, target_std):
            level = input_level.clone()
            level -= torch.mean(level)
            input_std = torch.sqrt(torch.mean(level * level))
            eps = 1e-6
            # Safeguard against divide by zero
            input_std[input_std < eps] = eps
            level /= input_std
            level *= target_std
            level += target_mean
            return level

        nbands = len(noise_pyramid[0]["b"])
        noise_pyramid[0]["h"] = match_level(
            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
        for l in range(len(noise_pyramid)-1):
            for b in range(nbands):
                noise_pyramid[l]["b"][b] = match_level(
                    noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

        metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
            noise_pyramid)
        metamer = ycrcb_2_rgb(metamer)
        # Crop to remove any padding
        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
        return metamer

    def __call__(self, image, target, gaze=[0.5, 0.5]):
        """ 
        Calculates the Metamer MSE Loss.

        Parameters
        ----------
        image   : torch.tensor
                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        target  : torch.tensor
                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
        gaze    : list
                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

        Returns
        -------

        loss                : torch.tensor
                                The computed loss.
        """
        check_loss_inputs("MetamerMSELoss", image, target)
        # Pad image and target if necessary
        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
        target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)

        if target is not self.target or self.target is None:
            self.target_metamer = self.gen_metamer(target, gaze)
            self.target = target

        return self.loss_func(image, self.target_metamer)

    def to(self, device):
        self.metameric_loss = self.metameric_loss.to(device)
        return self

__call__(image, target, gaze=[0.5, 0.5])

Calculates the Metamer MSE Loss.

Parameters:

  • image (torch.tensor) – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target (torch.tensor) – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • gaze (list) – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

Returns:

  • loss (torch.tensor) – The computed loss.

Source code in odak/learn/perception/metamer_mse_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):
    """ 
    Calculates the Metamer MSE Loss.

    Parameters
    ----------
    image   : torch.tensor
            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    target  : torch.tensor
            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    gaze    : list
            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

    Returns
    -------

    loss                : torch.tensor
                            The computed loss.
    """
    check_loss_inputs("MetamerMSELoss", image, target)
    # Pad image and target if necessary
    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
    target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)

    if target is not self.target or self.target is None:
        self.target_metamer = self.gen_metamer(target, gaze)
        self.target = target

    return self.loss_func(image, self.target_metamer)

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', n_pyramid_levels=5, n_orientations=2, equi=False)

Parameters:

  • alpha (float) – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width (float) – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance.
  • real_viewing_distance (float) – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width.
  • n_pyramid_levels (int) – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • mode (str) – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • n_orientations (int) – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
  • equi (bool) – If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The gaze argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].
Source code in odak/learn/perception/metamer_mse_loss.py
def __init__(self, device=torch.device("cpu"),
             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
             n_pyramid_levels=5, n_orientations=2, equi=False):
    """
    Parameters
    ----------
    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
    n_pyramid_levels        : int 
                                Number of levels of the steerable pyramid. Note that the image is padded
                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
                                too high will slow down the calculation a lot.
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    n_orientations          : int 
                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
                                Increasing this will increase runtime.
    equi                    : bool
                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                The gaze argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi/2]
    """
    self.target = None
    self.target_metamer = None
    self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,
                                        real_viewing_distance=real_viewing_distance,
                                        n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)
    self.loss_func = torch.nn.MSELoss()
    self.noise = None

gen_metamer(image, gaze)

Generates a metamer for an image, following the method in this paper (https://dl.acm.org/doi/abs/10.1145/3450626.3459943). This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image (torch.tensor) – Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions).
  • gaze (list) – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

Returns:

  • metamer (torch.tensor) – The generated metamer image.
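
As the docstring notes, gen_metamer can also be called on its own; a hedged sketch with illustrative shapes and gaze location, assuming the default constructor settings.

import torch
from odak.learn.perception import MetamerMSELoss

loss_fn = MetamerMSELoss()
target = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    metamer = loss_fn.gen_metamer(target, gaze=[0.25, 0.5])  # foveated metamer for a gaze left of centre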

Source code in odak/learn/perception/metamer_mse_loss.py
def gen_metamer(self, image, gaze):
    """ 
    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
    This function can be used on its own to generate a metamer for a desired image.

    Parameters
    ----------
    image   : torch.tensor
            Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
    gaze    : list
            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

    Returns
    -------

    metamer : torch.tensor
            The generated metamer image
    """
    image = rgb_2_ycrcb(image)
    image_size = image.size()
    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)

    target_stats = self.metameric_loss.calc_statsmaps(
        image, gaze=gaze, alpha=self.metameric_loss.alpha)
    target_means = target_stats[::2]
    target_stdevs = target_stats[1::2]
    if self.noise is None or self.noise.size() != image.size():
        torch.manual_seed(0)
        noise_image = torch.rand_like(image)
    noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
        noise_image, self.metameric_loss.n_pyramid_levels)
    input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
        image, self.metameric_loss.n_pyramid_levels)

    def match_level(input_level, target_mean, target_std):
        level = input_level.clone()
        level -= torch.mean(level)
        input_std = torch.sqrt(torch.mean(level * level))
        eps = 1e-6
        # Safeguard against divide by zero
        input_std[input_std < eps] = eps
        level /= input_std
        level *= target_std
        level += target_mean
        return level

    nbands = len(noise_pyramid[0]["b"])
    noise_pyramid[0]["h"] = match_level(
        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
    for l in range(len(noise_pyramid)-1):
        for b in range(nbands):
            noise_pyramid[l]["b"][b] = match_level(
                noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]

    metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
        noise_pyramid)
    metamer = ycrcb_2_rgb(metamer)
    # Crop to remove any padding
    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
    return metamer

RadiallyVaryingBlur

The RadiallyVaryingBlur class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given alpha parameter value. For more information on how the pooling sizes are computed, please see link coming soon.

The blur is accelerated by generating and sampling from MIP maps of the input image.

This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.

If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.
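
A hedged usage sketch, assuming RadiallyVaryingBlur is exported from odak.learn.perception; the display parameters and image size are illustrative.

import torch
from odak.learn.perception import RadiallyVaryingBlur

blur = RadiallyVaryingBlur()
image = torch.rand(1, 3, 512, 512)  # NCHW image
output = blur.blur(image, alpha=0.2, real_image_width=0.2,
                   real_viewing_distance=0.7, centre=(0.5, 0.5), mode="quadratic")
# A second call with the same parameters, gaze and image size reuses the cached pooling (LOD) maps.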

Source code in odak/learn/perception/radially_varying_blur.py
class RadiallyVaryingBlur():
    """ 

    The `RadiallyVaryingBlur` class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given `alpha` parameter value. For more information on how the pooling sizes are computed, please see [link coming soon]().

    The blur is accelerated by generating and sampling from MIP maps of the input image.

    This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.

    If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.

    """

    def __init__(self):
        self.lod_map = None
        self.equi = None

    def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
        """
        Apply the radially varying blur to an image.

        Parameters
        ----------

        image                   : torch.tensor
                                    The image to blur, in NCHW format.
        alpha                   : float
                                    parameter controlling foveation - larger values mean bigger pooling regions.
        real_image_width        : float 
                                    The real width of the image as displayed to the user.
                                    Units don't matter as long as they are the same as for real_viewing_distance.
                                    Ignored in equirectangular mode (equi==True)
        real_viewing_distance   : float 
                                    The real distance of the observer's eyes to the image plane.
                                    Units don't matter as long as they are the same as for real_image_width.
                                    Ignored in equirectangular mode (equi==True)
        centre                  : tuple of floats
                                    The centre of the radially varying blur (the gaze location).
                                    Should be a tuple of floats containing normalised image coordinates in range [0,1]
                                    In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
        mode                    : str 
                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                    as you move away from the fovea. We got best results with "quadratic".
        equi                    : bool
                                    If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                    The centre argument is instead interpreted as gaze angles, and should be in the range
                                    [-pi,pi]x[-pi/2,pi/2]

        Returns
        -------

        output                  : torch.tensor
                                    The blurred image
        """
        size = (image.size(-2), image.size(-1))

        # LOD map caching
        if self.lod_map is None or\
                self.size != size or\
                self.n_channels != image.size(1) or\
                self.alpha != alpha or\
                self.real_image_width != real_image_width or\
                self.real_viewing_distance != real_viewing_distance or\
                self.centre != centre or\
                self.mode != mode or\
                self.equi != equi:
            if not equi:
                self.lod_map = make_pooling_size_map_lod(
                    centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
            else:
                self.lod_map = make_equi_pooling_size_map_lod(
                    centre, (image.size(-2), image.size(-1)), alpha, mode)
            self.size = size
            self.n_channels = image.size(1)
            self.alpha = alpha
            self.real_image_width = real_image_width
            self.real_viewing_distance = real_viewing_distance
            self.centre = centre
            self.lod_map = self.lod_map.to(image.device)
            self.lod_fraction = torch.fmod(self.lod_map, 1.0)
            self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
                1, image.size(1), 1, 1)
            self.mode = mode
            self.equi = equi

        if self.lod_map.device != image.device:
            self.lod_map = self.lod_map.to(image.device)
        if self.lod_fraction.device != image.device:
            self.lod_fraction = self.lod_fraction.to(image.device)

        mipmap = [image]
        while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:
            mipmap.append(torch.nn.functional.interpolate(
                mipmap[-1], scale_factor=0.5, mode="area", recompute_scale_factor=False))
        if mipmap[-1].size(-1) == 2:
            final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]
            mipmap.append(final_mip)
        if mipmap[-1].size(-2) == 2:
            final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]
            mipmap.append(final_mip)

        for l in range(len(mipmap)):
            if l == len(mipmap)-1:
                mipmap[l] = mipmap[l] * \
                    torch.ones(image.size(), device=image.device)
            else:
                for l2 in range(l-1, -1, -1):
                    mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
                        image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)

        output = torch.zeros(image.size(), device=image.device)
        for l in range(len(mipmap)):
            if l == 0:
                mask = self.lod_map < (l+1)
            elif l == len(mipmap)-1:
                mask = self.lod_map >= l
            else:
                mask = torch.logical_and(
                    self.lod_map >= l, self.lod_map < (l+1))

            if l == len(mipmap)-1:
                blended_levels = mipmap[l]
            else:
                blended_levels = (1 - self.lod_fraction) * \
                    mipmap[l] + self.lod_fraction*mipmap[l+1]
            mask = mask[None, None, ...]
            mask = mask.repeat(1, image.size(1), 1, 1)
            output[mask] = blended_levels[mask]

        return output

blur(image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode='quadratic', equi=False)

Apply the radially varying blur to an image.

Parameters:

  • image (torch.tensor) – The image to blur, in NCHW format.
  • alpha (float) – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width (float) – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance. Ignored in equirectangular mode (equi==True).
  • real_viewing_distance (float) – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width. Ignored in equirectangular mode (equi==True).
  • centre (tuple of floats) – The centre of the radially varying blur (the gaze location). Should be a tuple of floats containing normalised image coordinates in range [0,1]. In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2].
  • mode (str) – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • equi (bool) – If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The centre argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].

Returns:

  • output (torch.tensor) – The blurred image.

Source code in odak/learn/perception/radially_varying_blur.py
def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
    """
    Apply the radially varying blur to an image.

    Parameters
    ----------

    image                   : torch.tensor
                                The image to blur, in NCHW format.
    alpha                   : float
                                parameter controlling foveation - larger values mean bigger pooling regions.
    real_image_width        : float 
                                The real width of the image as displayed to the user.
                                Units don't matter as long as they are the same as for real_viewing_distance.
                                Ignored in equirectangular mode (equi==True)
    real_viewing_distance   : float 
                                The real distance of the observer's eyes to the image plane.
                                Units don't matter as long as they are the same as for real_image_width.
                                Ignored in equirectangular mode (equi==True)
    centre                  : tuple of floats
                                The centre of the radially varying blur (the gaze location).
                                Should be a tuple of floats containing normalised image coordinates in range [0,1]
                                In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
    mode                    : str 
                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
                                as you move away from the fovea. We got best results with "quadratic".
    equi                    : bool
                                If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
                                The centre argument is instead interpreted as gaze angles, and should be in the range
                                [-pi,pi]x[-pi/2,pi/2]

    Returns
    -------

    output                  : torch.tensor
                                The blurred image
    """
    size = (image.size(-2), image.size(-1))

    # LOD map caching
    if self.lod_map is None or\
            self.size != size or\
            self.n_channels != image.size(1) or\
            self.alpha != alpha or\
            self.real_image_width != real_image_width or\
            self.real_viewing_distance != real_viewing_distance or\
            self.centre != centre or\
            self.mode != mode or\
            self.equi != equi:
        if not equi:
            self.lod_map = make_pooling_size_map_lod(
                centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
        else:
            self.lod_map = make_equi_pooling_size_map_lod(
                centre, (image.size(-2), image.size(-1)), alpha, mode)
        self.size = size
        self.n_channels = image.size(1)
        self.alpha = alpha
        self.real_image_width = real_image_width
        self.real_viewing_distance = real_viewing_distance
        self.centre = centre
        self.lod_map = self.lod_map.to(image.device)
        self.lod_fraction = torch.fmod(self.lod_map, 1.0)
        self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
            1, image.size(1), 1, 1)
        self.mode = mode
        self.equi = equi

    if self.lod_map.device != image.device:
        self.lod_map = self.lod_map.to(image.device)
    if self.lod_fraction.device != image.device:
        self.lod_fraction = self.lod_fraction.to(image.device)

    mipmap = [image]
    while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:
        mipmap.append(torch.nn.functional.interpolate(
            mipmap[-1], scale_factor=0.5, mode="area", recompute_scale_factor=False))
    if mipmap[-1].size(-1) == 2:
        final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]
        mipmap.append(final_mip)
    if mipmap[-1].size(-2) == 2:
        final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]
        mipmap.append(final_mip)

    for l in range(len(mipmap)):
        if l == len(mipmap)-1:
            mipmap[l] = mipmap[l] * \
                torch.ones(image.size(), device=image.device)
        else:
            for l2 in range(l-1, -1, -1):
                mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
                    image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)

    output = torch.zeros(image.size(), device=image.device)
    for l in range(len(mipmap)):
        if l == 0:
            mask = self.lod_map < (l+1)
        elif l == len(mipmap)-1:
            mask = self.lod_map >= l
        else:
            mask = torch.logical_and(
                self.lod_map >= l, self.lod_map < (l+1))

        if l == len(mipmap)-1:
            blended_levels = mipmap[l]
        else:
            blended_levels = (1 - self.lod_fraction) * \
                mipmap[l] + self.lod_fraction*mipmap[l+1]
        mask = mask[None, None, ...]
        mask = mask.repeat(1, image.size(1), 1, 1)
        output[mask] = blended_levels[mask]

    return output

SpatialSteerablePyramid

This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution) as opposed to multiplication in the Fourier domain. This has a number of optimisations over previous implementations that increase efficiency, but introduce some reconstruction error.
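
A hedged sketch of constructing and inverting a pyramid. It assumes SpatialSteerablePyramid is exported from odak.learn.perception (otherwise import it from odak.learn.perception.spatial_steerable_pyramid); sizes are illustrative.

import torch
from odak.learn.perception import SpatialSteerablePyramid

pyramid_maker = SpatialSteerablePyramid(use_bilinear_downup=True, n_channels=3,
                                        filter_size=9, n_orientations=6, filter_type="full")
image = torch.rand(1, 3, 256, 256)   # NCHW; height and width chosen divisible by 2^n_levels
pyramid = pyramid_maker.construct_pyramid(image, n_levels=5)
highpass = pyramid[0]['h']           # highpass residual at the finest level
bands = pyramid[0]['b']              # list of oriented band responses at the finest level
lowpass_residual = pyramid[-1]['l']  # lowpass residual at the coarsest level
reconstruction = pyramid_maker.reconstruct_from_pyramid(pyramid)  # approximate inverse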

Source code in odak/learn/perception/spatial_steerable_pyramid.py
class SpatialSteerablePyramid():
    """
    This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution)
    as opposed to multiplication in the Fourier domain.
    This has a number of optimisations over previous implementations that increase efficiency, but introduce some
    reconstruction error.
    """


    def __init__(self, use_bilinear_downup=True, n_channels=1,
                 filter_size=9, n_orientations=6, filter_type="full",
                 device=torch.device('cpu')):
        """
        Parameters
        ----------

        use_bilinear_downup     : bool
                                    This uses bilinear filtering when upsampling/downsampling, rather than the original approach
                                    of applying a large lowpass kernel and sampling even rows/columns
        n_channels              : int
                                    Number of channels in the input images (e.g. 3 for RGB input)
        filter_size             : int
                                    Desired size of filters (e.g. 3 will use 3x3 filters).
        n_orientations          : int
                                    Number of oriented bands in each level of the pyramid.
        filter_type             : str
                                    This can be used to select smaller filters than the original ones if desired.
                                    full: Original filter sizes
                                    cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
                                    trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.
        device                  : torch.device
                                    torch device the input images will be supplied from.
        """
        self.use_bilinear_downup = use_bilinear_downup
        self.device = device

        filters = get_steerable_pyramid_filters(
            filter_size, n_orientations, filter_type)

        def make_pad(filter):
            filter_size = filter.size(-1)
            pad_amt = (filter_size-1) // 2
            return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))

        if not self.use_bilinear_downup:
            self.filt_l = filters["l"].to(device)
            self.pad_l = make_pad(self.filt_l)
        self.filt_l0 = filters["l0"].to(device)
        self.pad_l0 = make_pad(self.filt_l0)
        self.filt_h0 = filters["h0"].to(device)
        self.pad_h0 = make_pad(self.filt_h0)
        for b in range(len(filters["b"])):
            filters["b"][b] = filters["b"][b].to(device)
        self.band_filters = filters["b"]
        self.pad_b = make_pad(self.band_filters[0])

        if n_channels != 1:
            def add_channels_to_filter(filter):
                padded = torch.zeros(n_channels, n_channels, filter.size()[
                                     2], filter.size()[3]).to(device)
                for channel in range(n_channels):
                    padded[channel, channel, :, :] = filter
                return padded
            self.filt_h0 = add_channels_to_filter(self.filt_h0)
            for b in range(len(self.band_filters)):
                self.band_filters[b] = add_channels_to_filter(
                    self.band_filters[b])
            self.filt_l0 = add_channels_to_filter(self.filt_l0)
            if not self.use_bilinear_downup:
                self.filt_l = add_channels_to_filter(self.filt_l)

    def construct_pyramid(self, image, n_levels, multiple_highpass=False):
        """
        Constructs and returns a steerable pyramid for the provided image.

        Parameters
        ----------

        image               : torch.tensor
                                The input image, in NCHW format. The number of channels C should match num_channels
                                when the pyramid maker was created.
        n_levels            : int
                                Number of levels in the constructed steerable pyramid.
        multiple_highpass   : bool
                                If true, computes a highpass for each level of the pyramid.
                                These extra levels are redundant (not used for reconstruction).

        Returns
        -------

        pyramid             : list of dicts of torch.tensor
                                The computed steerable pyramid.
                                Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.
                                Each level is stored as a dict, with the following keys:
                                "h" Highpass residual
                                "l" Lowpass residual
                                "b" Oriented bands (a list of torch.tensor)
        """
        pyramid = []

        # Make level 0, containing highpass, lowpass and the bands
        level0 = {}
        level0['h'] = torch.nn.functional.conv2d(
            self.pad_h0(image), self.filt_h0)
        lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
        level0['l'] = lowpass.clone()
        bands = []
        for filt_b in self.band_filters:
            bands.append(torch.nn.functional.conv2d(
                self.pad_b(lowpass), filt_b))
        level0['b'] = bands
        pyramid.append(level0)

        # Make intermediate levels
        for l in range(n_levels-2):
            level = {}
            if self.use_bilinear_downup:
                lowpass = torch.nn.functional.interpolate(
                    lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
            else:
                lowpass = torch.nn.functional.conv2d(
                    self.pad_l(lowpass), self.filt_l)
                lowpass = lowpass[:, :, ::2, ::2]
            level['l'] = lowpass.clone()
            bands = []
            for filt_b in self.band_filters:
                bands.append(torch.nn.functional.conv2d(
                    self.pad_b(lowpass), filt_b))
            level['b'] = bands
            if multiple_highpass:
                level['h'] = torch.nn.functional.conv2d(
                    self.pad_h0(lowpass), self.filt_h0)
            pyramid.append(level)

        # Make final level (lowpass residual)
        level = {}
        if self.use_bilinear_downup:
            lowpass = torch.nn.functional.interpolate(
                lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
        else:
            lowpass = torch.nn.functional.conv2d(
                self.pad_l(lowpass), self.filt_l)
            lowpass = lowpass[:, :, ::2, ::2]
        level['l'] = lowpass
        pyramid.append(level)

        return pyramid

    def reconstruct_from_pyramid(self, pyramid):
        """
        Reconstructs an input image from a steerable pyramid.

        Parameters
        ----------

        pyramid : list of dicts of torch.tensor
                    The steerable pyramid.
                    Should be in the same format as output by construct_steerable_pyramid().
                    The number of channels should match num_channels when the pyramid maker was created.

        Returns
        -------

        image   : torch.tensor
                    The reconstructed image, in NCHW format.         
        """
        def upsample(image, size):
            if self.use_bilinear_downup:
                return torch.nn.functional.interpolate(image, size=size, mode="bilinear", align_corners=False, recompute_scale_factor=False)
            else:
                zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[
                                    2]*2, image.size()[3]*2)).to(self.device)
                zeros[:, :, ::2, ::2] = image
                zeros = torch.nn.functional.conv2d(
                    self.pad_l(zeros), self.filt_l)
                return zeros

        image = pyramid[-1]['l']
        for level in reversed(pyramid[:-1]):
            image = upsample(image, level['b'][0].size()[2:])
            for b in range(len(level['b'])):
                b_filtered = torch.nn.functional.conv2d(
                    self.pad_b(level['b'][b]), -self.band_filters[b])
                image += b_filtered

        image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
        image += torch.nn.functional.conv2d(
            self.pad_h0(pyramid[0]['h']), self.filt_h0)

        return image

__init__(use_bilinear_downup=True, n_channels=1, filter_size=9, n_orientations=6, filter_type='full', device=torch.device('cpu'))

Parameters:

  • use_bilinear_downup
                        This uses bilinear filtering when upsampling/downsampling, rather than the original approach
                        of applying a large lowpass kernel and sampling even rows/columns
    
  • n_channels
                        Number of channels in the input images (e.g. 3 for RGB input)
    
  • filter_size
                        Desired size of filters (e.g. 3 will use 3x3 filters).
    
  • n_orientations
                        Number of oriented bands in each level of the pyramid.
    
  • filter_type
                        This can be used to select smaller filters than the original ones if desired.
                        full: Original filter sizes
                        cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
                        trained: Same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.
    
  • device
                        torch device the input images will be supplied from.
    
Source code in odak/learn/perception/spatial_steerable_pyramid.py
def __init__(self, use_bilinear_downup=True, n_channels=1,
             filter_size=9, n_orientations=6, filter_type="full",
             device=torch.device('cpu')):
    """
    Parameters
    ----------

    use_bilinear_downup     : bool
                                This uses bilinear filtering when upsampling/downsampling, rather than the original approach
                                of applying a large lowpass kernel and sampling even rows/columns
    n_channels              : int
                                Number of channels in the input images (e.g. 3 for RGB input)
    filter_size             : int
                                Desired size of filters (e.g. 3 will use 3x3 filters).
    n_orientations          : int
                                Number of oriented bands in each level of the pyramid.
    filter_type             : str
                                This can be used to select smaller filters than the original ones if desired.
                                full: Original filter sizes
                                cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
                                trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.
    device                  : torch.device
                                torch device the input images will be supplied from.
    """
    self.use_bilinear_downup = use_bilinear_downup
    self.device = device

    filters = get_steerable_pyramid_filters(
        filter_size, n_orientations, filter_type)

    def make_pad(filter):
        filter_size = filter.size(-1)
        pad_amt = (filter_size-1) // 2
        return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))

    if not self.use_bilinear_downup:
        self.filt_l = filters["l"].to(device)
        self.pad_l = make_pad(self.filt_l)
    self.filt_l0 = filters["l0"].to(device)
    self.pad_l0 = make_pad(self.filt_l0)
    self.filt_h0 = filters["h0"].to(device)
    self.pad_h0 = make_pad(self.filt_h0)
    for b in range(len(filters["b"])):
        filters["b"][b] = filters["b"][b].to(device)
    self.band_filters = filters["b"]
    self.pad_b = make_pad(self.band_filters[0])

    if n_channels != 1:
        def add_channels_to_filter(filter):
            padded = torch.zeros(n_channels, n_channels, filter.size()[
                                 2], filter.size()[3]).to(device)
            for channel in range(n_channels):
                padded[channel, channel, :, :] = filter
            return padded
        self.filt_h0 = add_channels_to_filter(self.filt_h0)
        for b in range(len(self.band_filters)):
            self.band_filters[b] = add_channels_to_filter(
                self.band_filters[b])
        self.filt_l0 = add_channels_to_filter(self.filt_l0)
        if not self.use_bilinear_downup:
            self.filt_l = add_channels_to_filter(self.filt_l)

construct_pyramid(image, n_levels, multiple_highpass=False)

Constructs and returns a steerable pyramid for the provided image.

Parameters:

  • image
                    The input image, in NCHW format. The number of channels C should match num_channels
                    when the pyramid maker was created.
    
  • n_levels
                    Number of levels in the constructed steerable pyramid.
    
  • multiple_highpass
                    If true, computes a highpass for each level of the pyramid.
                    These extra levels are redundant (not used for reconstruction).
    

Returns:

  • pyramid ( list of dicts of torch.tensor ) –

    The computed steerable pyramid. Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.
    Each level is stored as a dict, with the following keys:
    "h" Highpass residual
    "l" Lowpass residual
    "b" Oriented bands (a list of torch.tensor)

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def construct_pyramid(self, image, n_levels, multiple_highpass=False):
    """
    Constructs and returns a steerable pyramid for the provided image.

    Parameters
    ----------

    image               : torch.tensor
                            The input image, in NCHW format. The number of channels C should match num_channels
                            when the pyramid maker was created.
    n_levels            : int
                            Number of levels in the constructed steerable pyramid.
    multiple_highpass   : bool
                            If true, computes a highpass for each level of the pyramid.
                            These extra levels are redundant (not used for reconstruction).

    Returns
    -------

    pyramid             : list of dicts of torch.tensor
                            The computed steerable pyramid.
                            Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.
                            Each level is stored as a dict, with the following keys:
                            "h" Highpass residual
                            "l" Lowpass residual
                            "b" Oriented bands (a list of torch.tensor)
    """
    pyramid = []

    # Make level 0, containing highpass, lowpass and the bands
    level0 = {}
    level0['h'] = torch.nn.functional.conv2d(
        self.pad_h0(image), self.filt_h0)
    lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
    level0['l'] = lowpass.clone()
    bands = []
    for filt_b in self.band_filters:
        bands.append(torch.nn.functional.conv2d(
            self.pad_b(lowpass), filt_b))
    level0['b'] = bands
    pyramid.append(level0)

    # Make intermediate levels
    for l in range(n_levels-2):
        level = {}
        if self.use_bilinear_downup:
            lowpass = torch.nn.functional.interpolate(
                lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
        else:
            lowpass = torch.nn.functional.conv2d(
                self.pad_l(lowpass), self.filt_l)
            lowpass = lowpass[:, :, ::2, ::2]
        level['l'] = lowpass.clone()
        bands = []
        for filt_b in self.band_filters:
            bands.append(torch.nn.functional.conv2d(
                self.pad_b(lowpass), filt_b))
        level['b'] = bands
        if multiple_highpass:
            level['h'] = torch.nn.functional.conv2d(
                self.pad_h0(lowpass), self.filt_h0)
        pyramid.append(level)

    # Make final level (lowpass residual)
    level = {}
    if self.use_bilinear_downup:
        lowpass = torch.nn.functional.interpolate(
            lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
    else:
        lowpass = torch.nn.functional.conv2d(
            self.pad_l(lowpass), self.filt_l)
        lowpass = lowpass[:, :, ::2, ::2]
    level['l'] = lowpass
    pyramid.append(level)

    return pyramid
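
A short usage sketch follows, continuing from the pyramid maker created in the earlier example (the name pyramid_maker and its 3-channel setup are assumptions carried over from that sketch).

import torch

image = torch.rand(1, 3, 256, 256)  # NCHW; 256 is divisible by 2**5, so no padding is needed
pyramid = pyramid_maker.construct_pyramid(image, n_levels=5)

print(len(pyramid))              # 5 levels, ordered largest to smallest
print(pyramid[0]['h'].shape)     # highpass residual of the finest level
print(len(pyramid[0]['b']))      # one oriented band per orientation (6 here)
print(list(pyramid[-1].keys()))  # the final level holds only the lowpass residual 'l'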

reconstruct_from_pyramid(pyramid)

Reconstructs an input image from a steerable pyramid.

Parameters:

  • pyramid (list of dicts of torch.tensor) –
        The steerable pyramid.
        Should be in the same format as output by construct_pyramid().
        The number of channels should match num_channels when the pyramid maker was created.
    

Returns:

  • image ( tensor ) –

    The reconstructed image, in NCHW format.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def reconstruct_from_pyramid(self, pyramid):
    """
    Reconstructs an input image from a steerable pyramid.

    Parameters
    ----------

    pyramid : list of dicts of torch.tensor
                The steerable pyramid.
                Should be in the same format as output by construct_steerable_pyramid().
                The number of channels should match num_channels when the pyramid maker was created.

    Returns
    -------

    image   : torch.tensor
                The reconstructed image, in NCHW format.         
    """
    def upsample(image, size):
        if self.use_bilinear_downup:
            return torch.nn.functional.interpolate(image, size=size, mode="bilinear", align_corners=False, recompute_scale_factor=False)
        else:
            zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[
                                2]*2, image.size()[3]*2)).to(self.device)
            zeros[:, :, ::2, ::2] = image
            zeros = torch.nn.functional.conv2d(
                self.pad_l(zeros), self.filt_l)
            return zeros

    image = pyramid[-1]['l']
    for level in reversed(pyramid[:-1]):
        image = upsample(image, level['b'][0].size()[2:])
        for b in range(len(level['b'])):
            b_filtered = torch.nn.functional.conv2d(
                self.pad_b(level['b'][b]), -self.band_filters[b])
            image += b_filtered

    image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
    image += torch.nn.functional.conv2d(
        self.pad_h0(pyramid[0]['h']), self.filt_h0)

    return image
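
Continuing the same sketch, the pyramid can be passed straight back to check the round trip; only the residual is inspected here, since exact bit-for-bit reconstruction is not guaranteed.

reconstruction = pyramid_maker.reconstruct_from_pyramid(pyramid)
print(reconstruction.shape)                  # same NCHW shape as the input image
print((reconstruction - image).abs().max())  # magnitude of the round-trip error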

pad_image_for_pyramid(image, n_pyramid_levels)

Pads an image to the extent necessary to compute a steerable pyramid of the input image. This involves padding so both height and width are divisible by 2**n_pyramid_levels. Uses reflection padding.

Parameters:

  • image

    Image to pad, in NCHW format

  • n_pyramid_levels

    Number of levels in the pyramid you plan to construct.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def pad_image_for_pyramid(image, n_pyramid_levels):
    """
    Pads an image to the extent necessary to compute a steerable pyramid of the input image.
    This involves padding so both height and width are divisible by 2**n_pyramid_levels.
    Uses reflection padding.

    Parameters
    ----------

    image: torch.tensor
        Image to pad, in NCHW format
    n_pyramid_levels: int
        Number of levels in the pyramid you plan to construct.
    """
    min_divisor = 2 ** n_pyramid_levels
    height = image.size(2)
    width = image.size(3)
    required_height = math.ceil(height / min_divisor) * min_divisor
    required_width = math.ceil(width / min_divisor) * min_divisor
    if required_height > height or required_width > width:
        # We need to pad! ReflectionPad2d takes (left, right, top, bottom),
        # so pad the right edge by the width deficit and the bottom edge by the height deficit.
        pad = torch.nn.ReflectionPad2d(
            (0, required_width-width, 0, required_height-height))
        return pad(image)
    return image
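
A small sketch of using this helper before building a pyramid; the import path is an assumption inferred from the source file shown above.

import torch
# Assumed import path, based on the source file listed above.
from odak.learn.perception.spatial_steerable_pyramid import pad_image_for_pyramid

image = torch.rand(1, 3, 250, 300)  # neither side is a multiple of 2**5
padded = pad_image_for_pyramid(image, n_pyramid_levels=5)
print(padded.shape)                 # both sides rounded up to multiples of 32, e.g. 256 x 320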

crop_steerable_pyramid_filters(filters, size)

Given original 9x9 NYU filters, this crops them to the desired size. The size must be an odd number >= 3. Note that this only crops the h0, l0 and band filters (not the l downsampling filter).

Parameters:

  • filters
            Filters to crop (should be in the format used by get_steerable_pyramid_filters).
    
  • size
            Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.
    

Returns:

  • filters ( dict of torch.tensor ) –

    The cropped filters.

Source code in odak/learn/perception/steerable_pyramid_filters.py
def crop_steerable_pyramid_filters(filters, size):
    """
    Given original 9x9 NYU filters, this crops them to the desired size.
    The size must be an odd number >= 3
    Note this only crops the h0, l0 and band filters (not the l downsampling filter)

    Parameters
    ----------
    filters     : dict of torch.tensor
                    Filters to crop (should in format used by get_steerable_pyramid_filters.)
    size        : int
                    Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.

    Returns
    =======
    filters     : dict of torch.tensor
                    The cropped filters.
    """
    assert(size >= 3)
    assert(size % 2 == 1)
    r = (size-1) // 2

    def crop_filter(filter, r, normalise=True):
        r2 = (filter.size(-1)-1)//2
        filter = filter[:, :, r2-r:r2+r+1, r2-r:r2+r+1]
        if normalise:
            filter -= torch.sum(filter)
        return filter

    filters["h0"] = crop_filter(filters["h0"], r, normalise=False)
    sum_l = torch.sum(filters["l"])
    filters["l"] = crop_filter(filters["l"], 6, normalise=False)
    filters["l"] *= sum_l / torch.sum(filters["l"])
    sum_l0 = torch.sum(filters["l0"])
    filters["l0"] = crop_filter(filters["l0"], 2, normalise=False)
    filters["l0"] *= sum_l0 / torch.sum(filters["l0"])
    for b in range(len(filters["b"])):
        filters["b"][b] = crop_filter(filters["b"][b], r, normalise=True)
    return filters
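
A hedged sketch of cropping a full filter set down to 5x5 kernels; the import path is an assumption inferred from the source file shown above.

# Assumed import path, based on the source file listed above.
from odak.learn.perception.steerable_pyramid_filters import (
    get_steerable_pyramid_filters, crop_steerable_pyramid_filters)

filters = get_steerable_pyramid_filters(size=9, n_orientations=4, filter_type="full")
cropped = crop_steerable_pyramid_filters(filters, size=5)
print(cropped["h0"].shape)    # highpass residual cropped to 1 x 1 x 5 x 5
print(cropped["b"][0].shape)  # each oriented band filter is cropped to 5 x 5 as well
print(cropped["l"].shape)     # the l downsampling filter is not cropped to the requested size (see note above)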

get_steerable_pyramid_filters(size, n_orientations, filter_type)

This returns filters for a real-valued steerable pyramid.

Parameters:

  • size
                Width of the filters (e.g. 3 will return 3x3 filters)
    
  • n_orientations
                Number of oriented band filters
    
  • filter_type
                This can be used to select between the original NYU filters and cropped or trained alternatives.
                full: Original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py
                cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
                trained: Same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.
    

Returns:

  • filters ( dict of torch.tensor ) –

    The steerable pyramid filters. Returned as a dict with the following keys:
    "l" The lowpass downsampling filter
    "l0" The lowpass residual filter
    "h0" The highpass residual filter
    "b" The band filters (a list of torch.tensor filters, one for each orientation).
    
Source code in odak/learn/perception/steerable_pyramid_filters.py
def get_steerable_pyramid_filters(size, n_orientations, filter_type):
    """
    This returns filters for a real-valued steerable pyramid.

    Parameters
    ----------

    size            : int
                        Width of the filters (e.g. 3 will return 3x3 filters)
    n_orientations  : int
                        Number of oriented band filters
    filter_type     :  str
                        This can be used to select between the original NYU filters and cropped or trained alternatives.
                        full: Original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py
                        cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
                        trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.

    Returns
    =======

    filters         : dict of torch.tensor
                        The steerable pyramid filters. Returned as a dict with the following keys:
                        "l" The lowpass downsampling filter
                        "l0" The lowpass residual filter
                        "h0" The highpass residual filter
                        "b" The band filters (a list of torch.tensor filters, one for each orientation).
    """

    if filter_type != "full" and filter_type != "cropped" and filter_type != "trained":
        raise Exception(
            "Unknown filter type %s! Only filter types are full, cropped or trained." % filter_type)

    filters = {}
    if n_orientations == 1:
        filters["l"] = torch.tensor([
            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -
                1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04],
            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -
                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],
            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -
                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],
            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,
                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],
            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 3.220744e-02, 6.306262e-02, 7.624674e-02,
                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],
            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,
                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],
            [-1.871920e-03, -6.948900e-03, -3.781600e-03, 2.449600e-02, 7.624674e-02, 1.348999e-01, 1.576508e-01,
                1.348999e-01, 7.624674e-02, 2.449600e-02, -3.781600e-03, -6.948900e-03, -1.871920e-03],
            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,
                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],
            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 3.220744e-02, 6.306262e-02, 7.624674e-02,
                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],
            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,
                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],
            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -
                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],
            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -
                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],
            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04]]
        ).reshape(1, 1, 13, 13)
        filters["l0"] = torch.tensor([
            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -
                3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04],
            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -
                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],
            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,
                6.441488e-02, -1.344160e-02, -3.725800e-04],
            [-3.743860e-03, -7.563200e-03, 1.524935e-01, 3.153017e-01,
                1.524935e-01, -7.563200e-03, -3.743860e-03],
            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,
                6.441488e-02, -1.344160e-02, -3.725800e-04],
            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -
                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],
            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04]]
        ).reshape(1, 1, 7, 7)
        filters["h0"] = torch.tensor([
            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -
                2.406600e-04, -3.325600e-04, -3.324900e-04, -6.068000e-05, 5.997200e-04],
            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -
                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],
            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -
                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],
            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -
                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],
            [-2.406600e-04, -3.732100e-04, -2.420138e-02, -9.623594e-02,
                8.554893e-01, -9.623594e-02, -2.420138e-02, -3.732100e-04, -2.406600e-04],
            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -
                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],
            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -
                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],
            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -
                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],
            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -2.406600e-04, -3.325600e-04, -3.324900e-04, -6.068000e-05, 5.997200e-04]]
        ).reshape(1, 1, 9, 9)
        filters["b"] = []
        filters["b"].append(torch.tensor([
            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -
            1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05,
            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -
            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,
            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -
            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,
            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -
            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,
            -1.009473e-02, -9.091950e-03, -3.487008e-02, -3.158605e-02, 9.464195e-01, -
            3.158605e-02, -3.487008e-02, -9.091950e-03, -1.009473e-02,
            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -
            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,
            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -
            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,
            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -
            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,
            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05]
        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))

    elif n_orientations == 2:
        filters["l"] = torch.tensor(
            [[-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05],
             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,
                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],
             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,
                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],
             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -
                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],
             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -
                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],
             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,
                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],
             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,
                 6.806584e-02, 2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],
             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,
                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],
             [1.262000e-03, 5.589400e-03, 5.538200e-04, -9.637220e-03, -1.648568e-02, 1.438512e-02, 9.058014e-02, 1.773651e-01, 2.120374e-01,
                 1.773651e-01, 9.058014e-02, 1.438512e-02, -1.648568e-02, -9.637220e-03, 5.538200e-04, 5.589400e-03, 1.262000e-03],
             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,
                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],
             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,
                 6.806584e-02, 2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],
             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,
                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],
             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -
                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],
             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -
                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],
             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,
                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],
             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,
                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],
             [-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05]]
        ).reshape(1, 1, 17, 17)
        filters["l0"] = torch.tensor(
            [[-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05],
             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,
                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],
             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -
                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],
             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,
                 4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],
             [2.524010e-03, 1.107620e-03, -3.297137e-02, 1.811603e-01, 4.376250e-01,
                 1.811603e-01, -3.297137e-02, 1.107620e-03, 2.524010e-03],
             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,
                 4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],
             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -
                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],
             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,
                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],
             [-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05]]
        ).reshape(1, 1, 9, 9)
        filters["h0"] = torch.tensor(
            [[-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04],
             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,
                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],
             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -
                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],
             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -
                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],
             [-1.166810e-03, 1.098012e-02, -5.897780e-03, -1.562827e-01,
                 7.282593e-01, -1.562827e-01, -5.897780e-03, 1.098012e-02, -1.166810e-03],
             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -
                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],
             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -
                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],
             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,
                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],
             [-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04]]
        ).reshape(1, 1, 9, 9)
        filters["b"] = []
        filters["b"].append(torch.tensor(
            [6.125880e-03, -8.052600e-03, -2.103714e-02, -1.536890e-02, -1.851466e-02, -1.536890e-02, -2.103714e-02, -8.052600e-03, 6.125880e-03,
             -1.287416e-02, -9.611520e-03, 1.023569e-02, 6.009450e-03, 1.872620e-03, 6.009450e-03, 1.023569e-02, -
             9.611520e-03, -1.287416e-02,
             -5.641530e-03, 4.168400e-03, -2.382180e-02, -5.375324e-02, -
             2.076086e-02, -5.375324e-02, -2.382180e-02, 4.168400e-03, -5.641530e-03,
             -8.957260e-03, -1.751170e-03, -1.836909e-02, 1.265655e-01, 2.996168e-01, 1.265655e-01, -
             1.836909e-02, -1.751170e-03, -8.957260e-03,
             0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
             8.957260e-03, 1.751170e-03, 1.836909e-02, -1.265655e-01, -
             2.996168e-01, -1.265655e-01, 1.836909e-02, 1.751170e-03, 8.957260e-03,
             5.641530e-03, -4.168400e-03, 2.382180e-02, 5.375324e-02, 2.076086e-02, 5.375324e-02, 2.382180e-02, -
             4.168400e-03, 5.641530e-03,
             1.287416e-02, 9.611520e-03, -1.023569e-02, -6.009450e-03, -
             1.872620e-03, -6.009450e-03, -1.023569e-02, 9.611520e-03, 1.287416e-02,
             -6.125880e-03, 8.052600e-03, 2.103714e-02, 1.536890e-02, 1.851466e-02, 1.536890e-02, 2.103714e-02, 8.052600e-03, -6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
        filters["b"].append(torch.tensor(
            [-6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03,
             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -
             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,
             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -
             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,
             1.536890e-02, -6.009450e-03, 5.375324e-02, -
             1.265655e-01, 0.000000e+00, 1.265655e-01, -
             5.375324e-02, 6.009450e-03, -1.536890e-02,
             1.851466e-02, -1.872620e-03, 2.076086e-02, -
             2.996168e-01, 0.000000e+00, 2.996168e-01, -
             2.076086e-02, 1.872620e-03, -1.851466e-02,
             1.536890e-02, -6.009450e-03, 5.375324e-02, -
             1.265655e-01, 0.000000e+00, 1.265655e-01, -
             5.375324e-02, 6.009450e-03, -1.536890e-02,
             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -
             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,
             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -
             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,
             -6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))

    elif n_orientations == 4:
        filters["l"] = torch.tensor([
            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4,
                1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5],
            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,
                4.2872801423E-3, 2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],
            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,
                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],
            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -
                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],
            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -
                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, -3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],
            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,
                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],
            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,
                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],
            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,
                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],
            [1.2619999470E-3, 5.5893999524E-3, 5.5381999118E-4, -9.6372198313E-3, -1.6485679895E-2, 1.4385120012E-2, 9.0580143034E-2, 0.1773651242, 0.2120374441,
                0.1773651242, 9.0580143034E-2, 1.4385120012E-2, -1.6485679895E-2, -9.6372198313E-3, 5.5381999118E-4, 5.5893999524E-3, 1.2619999470E-3],
            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,
                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],
            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,
                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],
            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,
                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],
            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -
                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, -3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],
            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -
                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],
            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,
                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],
            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,
                4.2872801423E-3, 2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],
            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4, 1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5]]
        ).reshape(1, 1, 17, 17)
        filters["l0"] = torch.tensor([
            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4,
                2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5],
            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,
                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],
            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -
                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],
            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,
                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],
            [2.5240099058E-3, 1.1076199589E-3, -3.2971370965E-2, 0.1811603010, 0.4376249909,
                0.1811603010, -3.2971370965E-2, 1.1076199589E-3, 2.5240099058E-3],
            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,
                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],
            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -
                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],
            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,
                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],
            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4, 2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5]]
        ).reshape(1, 1, 9, 9)
        filters["h0"] = torch.tensor([
            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3,
                2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4],
            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,
                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],
            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -
                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],
            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -
                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],
            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, -2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,
                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],
            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -
                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],
            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -
                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],
            [2.0898699295E-3, 1.9755100366E-3, -8.3368999185E-4, -5.1014497876E-3, 7.3032598011E-3, -9.3812197447E-3, -0.1554141641,
                0.7303866148, -0.1554141641, -9.3812197447E-3, 7.3032598011E-3, -5.1014497876E-3, -8.3368999185E-4, 1.9755100366E-3, 2.0898699295E-3],
            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -
                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],
            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -
                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],
            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, -2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,
                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],
            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -
                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],
            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -
                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],
            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,
                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],
            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3, 2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4]]
        ).reshape(1, 1, 15, 15)
        filters["b"] = []
        filters["b"].append(torch.tensor(
            [-8.1125000725E-4, 4.4451598078E-3, 1.2316980399E-2, 1.3955879956E-2,  1.4179450460E-2, 1.3955879956E-2, 1.2316980399E-2, 4.4451598078E-3, -8.1125000725E-4,
             3.9103501476E-3, 4.4565401040E-3, -5.8724298142E-3, -2.8760801069E-3, 8.5267601535E-3, -
             2.8760801069E-3, -5.8724298142E-3, 4.4565401040E-3, 3.9103501476E-3,
             1.3462699717E-3, -3.7740699481E-3, 8.2581602037E-3, 3.9442278445E-2, 5.3605638444E-2, 3.9442278445E-2, 8.2581602037E-3, -
             3.7740699481E-3, 1.3462699717E-3,
             7.4700999539E-4, -3.6522001028E-4, -2.2522680461E-2, -0.1105690673, -
             0.1768419296, -0.1105690673, -2.2522680461E-2, -3.6522001028E-4, 7.4700999539E-4,
             0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
             -7.4700999539E-4, 3.6522001028E-4, 2.2522680461E-2, 0.1105690673, 0.1768419296, 0.1105690673, 2.2522680461E-2, 3.6522001028E-4, -7.4700999539E-4,
             -1.3462699717E-3, 3.7740699481E-3, -8.2581602037E-3, -3.9442278445E-2, -
             5.3605638444E-2, -3.9442278445E-2, -
             8.2581602037E-3, 3.7740699481E-3, -1.3462699717E-3,
             -3.9103501476E-3, -4.4565401040E-3, 5.8724298142E-3, 2.8760801069E-3, -
             8.5267601535E-3, 2.8760801069E-3, 5.8724298142E-3, -
             4.4565401040E-3, -3.9103501476E-3,
             8.1125000725E-4, -4.4451598078E-3, -1.2316980399E-2, -1.3955879956E-2, -1.4179450460E-2, -1.3955879956E-2, -1.2316980399E-2, -4.4451598078E-3, 8.1125000725E-4]
        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
        filters["b"].append(torch.tensor(
            [0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3,
             8.2846998703E-4, 0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -
             2.0865600090E-3, 2.3298799060E-3, -
             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,
             5.7109999034E-5, 9.7479997203E-4, 0.0000000000, -1.2145539746E-2, -
             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -
             4.4814897701E-3, 1.4807609841E-2,
             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -
             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,
             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -
             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,
             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, 0.0000000000, -
             1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,
             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -
             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -
             9.7479997203E-4, -5.7109999034E-5,
             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -
             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, 0.0000000000, -8.2846998703E-4,
             3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000]
        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
        filters["b"].append(torch.tensor(
            [8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4,
             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -
             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,
             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -
             2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,
             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -
             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,
             -1.4179450460E-2, -8.5267601535E-3, -5.3605638444E-2, 0.1768419296, 0.0000000000, -
             0.1768419296, 5.3605638444E-2, 8.5267601535E-3, 1.4179450460E-2,
             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -
             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,
             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -
             2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,
             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -
             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,
             8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4]
        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
        filters["b"].append(torch.tensor(
            [3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000,
             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -
             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, -
             0.0000000000, -8.2846998703E-4,
             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -
             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -
             9.7479997203E-4, -5.7109999034E-5,
             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, -
             0.0000000000, -1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,
             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -
             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,
             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -
             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,
             5.7109999034E-5, 9.7479997203E-4, -0.0000000000, -1.2145539746E-2, -
             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -
             4.4814897701E-3, 1.4807609841E-2,
             8.2846998703E-4, -0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -
             2.0865600090E-3, 2.3298799060E-3, -
             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,
             0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3]
        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))

    elif n_orientations == 6:
        filters["l"] = 2 * torch.tensor([
            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -
                0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404],
            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,
                0.00410600, -0.00661117, -0.00523281, -0.00244917],
            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,
                0.03277038, 0.01396746, -0.00661117, -0.00387812],
            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,
                0.06426333, 0.03277038, 0.00410600, -0.00944432],
            [-0.00962054, 0.01002988, 0.03981393, 0.08169618, 0.10096540,
                0.08169618, 0.03981393, 0.01002988, -0.00962054],
            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,
                0.06426333, 0.03277038, 0.00410600, -0.00944432],
            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,
                0.03277038, 0.01396746, -0.00661117, -0.00387812],
            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,
                0.00410600, -0.00661117, -0.00523281, -0.00244917],
            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404]]
        ).reshape(1, 1, 9, 9)
        filters["l0"] = torch.tensor([
            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614],
            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],
            [-0.03848215, 0.15925570, 0.40304148, 0.15925570, -0.03848215],
            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],
            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614]]
        ).reshape(1, 1, 5, 5)
        filters["h0"] = torch.tensor([
            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429],
            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227, 0.00631653, -0.00243812, -0.00350017, -0.00113093],
            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],
            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],
            [-0.00080639, 0.01261227, -0.00981051, -0.11435863, 0.81380200, -0.11435863, -0.00981051, 0.01261227, -0.00080639],
            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],
            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],
            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227, 0.00631653, -0.00243812, -0.00350017, -0.00113093],
            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429]]
        ).reshape(1, 1, 9, 9)
        filters["b"] = []
        filters["b"].append(torch.tensor([
            0.00277643, 0.00496194, 0.01026699, 0.01455399, 0.01026699, 0.00496194, 0.00277643,
            -0.00986904, -0.00893064, 0.01189859, 0.02755155, 0.01189859, -0.00893064, -0.00986904,
            -0.01021852, -0.03075356, -0.08226445, -0.11732297, -0.08226445, -0.03075356, -0.01021852,
            0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
            0.01021852, 0.03075356, 0.08226445, 0.11732297, 0.08226445, 0.03075356, 0.01021852,
            0.00986904, 0.00893064, -0.01189859, -0.02755155, -0.01189859, 0.00893064, 0.00986904,
            -0.00277643, -0.00496194, -0.01026699, -0.01455399, -0.01026699, -0.00496194, -0.00277643]
        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
        filters["b"].append(torch.tensor([
            -0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982,
            -0.00358461, -0.01977507, -0.04084211, -0.00228219, 0.03930573, 0.01161195, 0.00128000,
            0.01047717, 0.01486305, -0.04819057, -0.12227230, -0.05394139, 0.00853965, -0.00459034,
            0.00790407, 0.04435647, 0.09454202, -0.00000000, -0.09454202, -0.04435647, -0.00790407,
            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,
            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,
            -0.01166982, -0.00285723, -0.00182078, -0.01124321, 0.00073141, 0.00640815, 0.00343249]
        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
        filters["b"].append(torch.tensor([
            0.00343249, 0.00358461, -0.01047717, -0.00790407, -0.00459034, 0.00128000, 0.01166982,
            0.00640815, 0.01977507, -0.01486305, -0.04435647, 0.00853965, 0.01161195, 0.00285723,
            0.00073141, 0.04084211, 0.04819057, -0.09454202, -0.05394139, 0.03930573, 0.00182078,
            -0.01124321, 0.00228219, 0.12227230, -0.00000000, -0.12227230, -0.00228219, 0.01124321,
            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -0.04819057, -0.04084211, -0.00073141,
            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,
            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249]
        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
        filters["b"].append(torch.tensor(
             [-0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643,
              -0.00496194, 0.00893064, 0.03075356, -0.00000000, -0.03075356, -0.00893064, 0.00496194,
              -0.01026699, -0.01189859, 0.08226445, -0.00000000, -0.08226445, 0.01189859, 0.01026699,
              -0.01455399, -0.02755155, 0.11732297, -0.00000000, -0.11732297, 0.02755155, 0.01455399,
              -0.01026699, -0.01189859, 0.08226445, -0.00000000, -0.08226445, 0.01189859, 0.01026699,
              -0.00496194, 0.00893064, 0.03075356, -0.00000000, -0.03075356, -0.00893064, 0.00496194,
              -0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643]
        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
        filters["b"].append(torch.tensor([
            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249,
            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,
            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -0.04819057, -0.04084211, -0.00073141,
            -0.01124321, 0.00228219, 0.12227230, -0.00000000, -0.12227230, -0.00228219, 0.01124321,
            0.00073141, 0.04084211, 0.04819057, -0.09454202, -0.05394139, 0.03930573, 0.00182078,
            0.00640815, 0.01977507, -0.01486305, -0.04435647, 0.00853965, 0.01161195, 0.00285723,
            0.00343249, 0.00358461, -0.01047717, -0.00790407, -0.00459034, 0.00128000, 0.01166982]
        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
        filters["b"].append(torch.tensor([
            -0.01166982, -0.00285723, -0.00182078, -0.01124321, 0.00073141, 0.00640815, 0.00343249,
            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,
            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,
            0.00790407, 0.04435647, 0.09454202, -0.00000000, -0.09454202, -0.04435647, -0.00790407,
            0.01047717, 0.01486305, -0.04819057, -0.12227230, -0.05394139, 0.00853965, -0.00459034,
            -0.00358461, -0.01977507, -0.04084211, -0.00228219, 0.03930573, 0.01161195, 0.00128000,
            -0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982]
        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))

    else:
        raise Exception(
            "Steerable filters not implemented for %d orientations" % n_orientations)

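    # Optionally replace the first two band filters with trained 5x5 versions
    # (the filter bank is first cropped to 5x5 before the trained kernels are assigned).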
    if filter_type == "trained":
        if size == 5:
            # TODO maybe also train h0 and l0 filters
            filters = crop_steerable_pyramid_filters(filters, 5)
            filters["b"][0] = torch.tensor([
                [-0.0356752239, -0.0223877281, -0.0009542659, 0.0244821459, 0.0322226137],
                [-0.0593218654, 0.1245803162, -0.0023863907, -0.1230178699, 0.0589442067],
                [-0.0281576272, 0.2976626456, -0.0020888755, -0.2953369915, 0.0284542721],
                [-0.0586092323, 0.1251581162, -0.0024624448, -0.1227868199, 0.0587830991],
                [-0.0327464789, -0.0223652460, -0.0042342511, 0.0245472137, 0.0359398536]
            ]).reshape(1, 1, 5, 5)
            filters["b"][1] = torch.tensor([
                [3.9758663625e-02, 6.0679119080e-02, 3.0146904290e-02, 6.1198268086e-02, 3.6218870431e-02],
                [2.3255519569e-02, -1.2505133450e-01, -2.9738345742e-01, -1.2518258393e-01, 2.3592948914e-02],
                [-1.3602430699e-03, -1.2058277935e-04, 2.6399988565e-04, -2.3791544663e-04, 1.8450465286e-03],
                [-2.1563466638e-02, 1.2572696805e-01, 2.9745018482e-01, 1.2458638102e-01, -2.3847281933e-02],
                [-3.7941932678e-02, -6.1060950160e-02, -2.9489086941e-02, -6.0411967337e-02, -3.8459088653e-02]
            ]).reshape(1, 1, 5, 5)
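            # Note: the two trained filters assigned above are overwritten by the
            # retrained versions below, so they have no effect at runtime; they are
            # kept only as a record of earlier training runs.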

            # The filters below were optimised on 09/02/2021:
            # 20K iterations with multiple images at more scales.
            filters["b"][0] = torch.tensor([
                [-4.5508436859e-02, -2.1767273545e-02, -1.9399923622e-04, 2.1200872958e-02, 4.5475799590e-02],
                [-6.3554823399e-02, 1.2832683325e-01, -5.3858719184e-05, -1.2809979916e-01, 6.3842624426e-02],
                [-3.4809380770e-02, 2.9954621196e-01, 2.9066693969e-05, -2.9957753420e-01, 3.4806568176e-02],
                [-6.3934154809e-02, 1.2806062400e-01, 9.0917674243e-05, -1.2832444906e-01, 6.3572973013e-02],
                [-4.5492250472e-02, -2.1125273779e-02, 4.2229349492e-04, 2.1804777905e-02, 4.5236673206e-02]
            ]).reshape(1, 1, 5, 5)
            filters["b"][1] = torch.tensor([
                [4.8947390169e-02, 6.3575074077e-02, 3.4955859184e-02, 6.4085893333e-02, 4.9838040024e-02],
                [2.2061849013e-02, -1.2936264277e-01, -3.0093491077e-01, -1.2997294962e-01, 2.0597217605e-02],
                [-5.1290717238e-05, -1.7305796064e-05, 2.0256420612e-05, -1.1864109547e-04, 7.3973249528e-05],
                [-2.0749464631e-02, 1.2988376617e-01, 3.0080935359e-01, 1.2921217084e-01, -2.2159902379e-02],
                [-4.9614857882e-02, -6.4021714032e-02, -3.4676689655e-02, -6.3446544111e-02, -4.8282280564e-02]
            ]).reshape(1, 1, 5, 5)

            # Trained on 17/02/2021 to match the Fourier pyramid in the spatial domain
            filters["b"][0] = torch.tensor([
                [3.3370e-02,  9.3934e-02, -3.5810e-04, -9.4038e-02, -3.3115e-02],
                [1.7716e-01,  3.9378e-01,  6.8461e-05, -3.9343e-01, -1.7685e-01],
                [2.9213e-01,  6.1042e-01,  7.0654e-04, -6.0939e-01, -2.9177e-01],
                [1.7684e-01,  3.9392e-01,  1.0517e-03, -3.9268e-01, -1.7668e-01],
                [3.3000e-02,  9.4029e-02,  7.3565e-04, -9.3366e-02, -3.3008e-02]
            ]).reshape(1, 1, 5, 5) * 0.1

            filters["b"][1] = torch.tensor([
                [0.0331,  0.1763,  0.2907,  0.1753,  0.0325],
                [0.0941,  0.3932,  0.6079,  0.3904,  0.0922],
                [0.0008,  0.0009, -0.0010, -0.0025, -0.0015],
                [-0.0929, -0.3919, -0.6097, -0.3944, -0.0946],
                [-0.0328, -0.1760, -0.2915, -0.1768, -0.0333]
            ]).reshape(1, 1, 5, 5) * 0.1

        else:
            raise Exception(
                "Trained filters not implemented for size %d" % size)

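    # Optionally crop the filters down to the requested spatial size.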
    if filter_type == "cropped":
        filters = crop_steerable_pyramid_filters(filters, size)

    return filters
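
The dictionary returned above can be applied with ordinary 2D convolutions. The snippet below is a minimal usage sketch, not part of the odak source: it assumes filters holds the six-orientation bank defined in this function and that image is a hypothetical single-channel input tensor, with padding chosen so each output keeps the input resolution.

import torch
import torch.nn.functional as F

# Hypothetical example: filter a grayscale image with the six-orientation bank.
# `filters` is assumed to be the dictionary returned by the function above.
image = torch.rand(1, 1, 128, 128)                             # batch x channel x height x width
low = F.conv2d(image, filters["l0"], padding=2)                # 5x5 lowpass kernel, padding=2 keeps size
high = F.conv2d(image, filters["h0"], padding=4)               # 9x9 highpass kernel, padding=4 keeps size
bands = [F.conv2d(image, b, padding=3) for b in filters["b"]]  # six 7x7 oriented band responses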