odak.learn.wave

angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution with Angular Spectrum method for beam propagation.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
    """
    A definition to calculate convolution with Angular Spectrum method for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Angular Spectrum',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
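
As a reading aid, the operation performed by `angular_spectrum` can be written out as below. This is a sketch inferred from the listings on this page: the transfer function follows `get_angular_spectrum_kernel` and the Fourier-domain product follows `custom` with the default `aperture = 1`. The symbols $U_0$ (input field), $U_z$ (output field), $f_x$, $f_y$ (spatial frequencies), $z$ (`distance`) and $\lambda$ (`wavelength`) are notation introduced here, not names from the library.

```latex
U_z(x, y) = \mathcal{F}^{-1}\left\{ \mathcal{F}\{U_0(x, y)\} \, H(f_x, f_y) \right\},
\qquad
H(f_x, f_y) = \exp\left( j \, \frac{2 \pi z}{\lambda} \sqrt{1 - (\lambda f_x)^2 - (\lambda f_y)^2} \right)
```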

band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate bandlimited angular spectrum based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673.

Parameters:

  • field
               A complex field.
               The expected size is [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
    """
    A definition to calculate bandlimited angular spectrum based beam propagation. For more, see
    `Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673`.

    Parameters
    ----------
    field            : torch.complex
                       A complex field.
                       The expected size is [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Bandlimited Angular Spectrum',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result

custom(field, kernel, zero_padding=False, aperture=1.0)

A definition to propagate a beam by multiplying the Fourier transform of a field with a custom kernel, which is equivalent to a convolution in the spatial domain.

Parameters:

  • field
               Complex field [m x n].
    
  • kernel
               Custom complex kernel for beam propagation.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def custom(field, kernel, zero_padding = False, aperture = 1.):
    """
    A definition to propagate a beam by multiplying the Fourier transform of a field with a custom kernel, which is equivalent to a convolution in the spatial domain.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    kernel           : torch.complex
                       Custom complex kernel for beam propagation.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    if kernel is None:
        H = torch.ones(field.shape).to(field.device)
    else:
        # Note: the aperture is applied to the kernel here and to the field spectrum below.
        H = kernel * aperture
    U1 = torch.fft.fftshift(torch.fft.fft2(field)) * aperture
    if zero_padding:
        U2 = zero_pad(H * U1)
    else:
        U2 = H * U1
    result = torch.fft.ifft2(torch.fft.ifftshift(U2))
    return result
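
The Fourier-domain product in `custom` can be illustrated with a dependency-free, one-dimensional sketch. Everything below is a stand-in written for this example: `dft`, `fftshift`, `ifftshift` and `propagate_1d` are hypothetical helpers in plain Python rather than the `torch.fft` calls used above, the aperture is assumed to be a scalar, and zero padding is left out.

```python
import cmath


def dft(values, inverse=False):
    """Naive O(n^2) discrete Fourier transform; the inverse is scaled by 1/n."""
    n = len(values)
    sign = 1 if inverse else -1
    out = []
    for p in range(n):
        total = sum(v * cmath.exp(sign * 2j * cmath.pi * p * q / n)
                    for q, v in enumerate(values))
        out.append(total / n if inverse else total)
    return out


def fftshift(values):
    """Rotate the list so the zero-frequency sample sits at index n // 2."""
    half = len(values) // 2
    return values[-half:] + values[:-half] if half else list(values)


def ifftshift(values):
    """Undo fftshift."""
    half = len(values) // 2
    return values[half:] + values[:half]


def propagate_1d(field, kernel, aperture=1.0):
    """Hypothetical 1D analogue of `custom` with a scalar aperture, no zero padding."""
    spectrum = [u * aperture for u in fftshift(dft(field))]
    filtered = [h * aperture * u for h, u in zip(kernel, spectrum)]
    return dft(ifftshift(filtered), inverse=True)


field = [1 + 0j, 2 + 0j, 0.5j, -1 + 0j]

# An all-ones kernel leaves the field unchanged.
identity = [1 + 0j] * len(field)
unchanged = propagate_1d(field, identity)

# A unit-modulus (phase-only) kernel redistributes the field but keeps its energy.
phase_only = [cmath.exp(0.3j * index) for index in range(len(field))]
propagated = propagate_1d(field, phase_only)
energy_in = sum(abs(u) ** 2 for u in field)
energy_out = sum(abs(u) ** 2 for u in propagated)
```

The two checks reflect what one would expect of a lossless propagation kernel: an identity kernel is a no-op, and a pure phase kernel conserves the total energy of the field.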

fraunhofer(field, k, distance, dx, wavelength)

A definition to calculate light transport using Fraunhofer approximation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def fraunhofer(field, k, distance, dx, wavelength):
    """
    A definition to calculate light transport using Fraunhofer approximation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).
    """
    nv, nu = field.shape[-1], field.shape[-2]
    x = torch.linspace(-nv*dx/2, nv*dx/2, nv, dtype=torch.float32)
    y = torch.linspace(-nu*dx/2, nu*dx/2, nu, dtype=torch.float32)
    Y, X = torch.meshgrid(y, x, indexing='ij')
    Z = torch.pow(X, 2) + torch.pow(Y, 2)
    c = 1. / (1j * wavelength * distance) * torch.exp(1j * k * 0.5 / distance * Z)
    c = c.to(field.device)
    result = c * torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(field))) * dx ** 2
    return result
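
Read off the listing above, the value returned by `fraunhofer` corresponds to the expression below. This is a transcription of the code rather than an independent derivation; the FFT shifts are omitted, and $U_0$, $U_z$, $z$ (`distance`), $\lambda$ (`wavelength`) and $\Delta x$ (`dx`) are notation introduced here.

```latex
U_z(x, y) = \frac{\Delta x^2}{j \lambda z}
            \exp\left( \frac{j k}{2 z} \left( x^2 + y^2 \right) \right)
            \mathcal{F}\{U_0\}
```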

gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel')

Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

Parameters:

  • field
               Complex field (MxN).
    
  • n_iterations
               Number of iterations.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • slm_range
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Type of the propagation (see odak.learn.wave.propagate_beam).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

  • reconstruction ( cfloat ) –

    Calculated reconstruction using calculated hologram.

Source code in odak/learn/wave/classical.py
def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel'):
    """
    Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

    Parameters
    ----------
    field            : torch.cfloat
                       Complex field (MxN).
    n_iterations     : int
                       Number of iterations.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    slm_range        : float
                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    propagation_type : str
                       Type of the propagation (see odak.learn.wave.propagate_beam).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    reconstruction   : torch.cfloat
                       Calculated reconstruction using calculated hologram. 
    """
    k = wavenumber(wavelength)
    reconstruction = field
    for i in range(n_iterations):
        hologram = propagate_beam(
            reconstruction, k, -distance, dx, wavelength, propagation_type)
        reconstruction = propagate_beam(
            hologram, k, distance, dx, wavelength, propagation_type)
        reconstruction = set_amplitude(reconstruction, field)
    reconstruction = propagate_beam(
        hologram, k, distance, dx, wavelength, propagation_type)
    return hologram, reconstruction
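
For readers new to the method, here is a minimal one-dimensional sketch of a Gerchberg-Saxton style loop in plain Python. It rests on several assumptions that are not taken from the library: a naive DFT stands in for `propagate_beam`, `impose_amplitude` is a local helper playing the role of `set_amplitude`, and the hologram is forced to unit amplitude (phase-only) at every iteration, which the listing above does not do.

```python
import cmath


def dft(values, inverse=False):
    """Naive O(n^2) discrete Fourier transform; the inverse is scaled by 1/n."""
    n = len(values)
    sign = 1 if inverse else -1
    out = []
    for p in range(n):
        total = sum(v * cmath.exp(sign * 2j * cmath.pi * p * q / n)
                    for q, v in enumerate(values))
        out.append(total / n if inverse else total)
    return out


def impose_amplitude(field, amplitudes):
    """Keep the phase of each sample in `field` and impose the given amplitudes."""
    return [a * cmath.exp(1j * cmath.phase(u)) for u, a in zip(field, amplitudes)]


def gerchberg_saxton_1d(target_amplitude, n_iterations):
    """Alternate between the hologram and image planes, constraining each in turn."""
    unit = [1.0] * len(target_amplitude)
    reconstruction = [complex(a) for a in target_amplitude]
    # Back-propagate the target and keep only the phase (phase-only hologram).
    hologram = impose_amplitude(dft(reconstruction, inverse=True), unit)
    for _ in range(n_iterations):
        # Forward-propagate, then replace the amplitude with the target amplitude.
        reconstruction = impose_amplitude(dft(hologram), target_amplitude)
        # Back-propagate and enforce the phase-only constraint again.
        hologram = impose_amplitude(dft(reconstruction, inverse=True), unit)
    return hologram, dft(hologram)


target = [0.0, 1.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0]
hologram, reconstruction = gerchberg_saxton_1d(target, n_iterations=20)
reconstruction_energy = sum(abs(u) ** 2 for u in reconstruction)
```

The returned `reconstruction` is the unconstrained forward propagation of the final hologram, mirroring the last `propagate_beam` call in the listing above.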

get_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
    """
    Helper function for odak.learn.wave.angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    H = torch.exp(1j  * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
    H = H.to(device)
    return H
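
To make the kernel expression concrete, the sketch below evaluates the same transfer function at a single frequency pair using only the standard library. The function name `angular_spectrum_transfer` is hypothetical, and the use of `cmath.sqrt` (which keeps evanescent components finite) is a choice made for this sketch, not a description of how the library treats them.

```python
import cmath
import math


def angular_spectrum_transfer(fx, fy, wavelength=515e-9, distance=0.0):
    """Evaluate exp(j * 2 * pi * z / lambda * sqrt(1 - (lambda fx)^2 - (lambda fy)^2))."""
    argument = 1.0 - (wavelength * fx) ** 2 - (wavelength * fy) ** 2
    # cmath.sqrt returns an imaginary root when argument < 0 (evanescent waves).
    phase = 2.0 * math.pi * distance / wavelength * cmath.sqrt(argument)
    return cmath.exp(1j * phase)


dx = 8e-6
nyquist = 1.0 / (2.0 * dx)  # highest frequency sampled by the kernel grid

on_axis = angular_spectrum_transfer(0.0, 0.0, distance=1e-3)
at_nyquist = angular_spectrum_transfer(nyquist, nyquist, distance=1e-3)
evanescent = angular_spectrum_transfer(2.0 / 515e-9, 0.0, distance=1e-6)
```

With the default pixel pitch and wavelength, the product of wavelength and Nyquist frequency is about 0.032, so the square-root argument stays positive over the whole grid and the kernel has unit magnitude everywhere. Only frequencies above `1 / wavelength` would decay.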

get_band_limited_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.band_limited_angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
    """
    Helper function for odak.learn.wave.band_limited_angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex
                         Complex kernel in Fourier domain.
    """
    x = dx * float(nu)
    y = dx * float(nv)
    fx = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * x),
                         1 / (2 * dx) - 0.5 / (2 * x),
                         nu,
                         dtype = torch.float32,
                         device = device
                        )
    fy = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * y),
                        1 / (2 * dx) - 0.5 / (2 * y),
                        nv,
                        dtype = torch.float32,
                        device = device
                       )
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    HH_exp = 2 * torch.pi * torch.sqrt(1 / wavelength ** 2 - (FX ** 2 + FY ** 2))
    distance = torch.tensor([distance], device = device)
    H_exp = torch.mul(HH_exp, distance)
    fx_max = 1 / torch.sqrt((2 * distance * (1 / x))**2 + 1) / wavelength
    fy_max = 1 / torch.sqrt((2 * distance * (1 / y))**2 + 1) / wavelength
    H_filter = ((torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).clone().detach()
    H = generate_complex_field(H_filter, H_exp)
    return H
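
The band limit computed as `fx_max` and `fy_max` above can be explored numerically with a small standalone sketch. The helper name `band_limit` and the chosen resolution of 1920 pixels are illustrative assumptions; the formula itself is copied from the listing.

```python
import math


def band_limit(n_pixels, dx, wavelength, distance):
    """Cut-off frequency in cycles per metre along one axis of the grid."""
    extent = dx * float(n_pixels)  # physical size of the grid along this axis
    return 1.0 / math.sqrt((2.0 * distance / extent) ** 2 + 1.0) / wavelength


dx, wavelength = 8e-6, 515e-9
nyquist = 1.0 / (2.0 * dx)

near = band_limit(1920, dx, wavelength, distance=0.0)
far = band_limit(1920, dx, wavelength, distance=1.0)
```

At zero distance the limit equals `1 / wavelength`, which lies above the Nyquist frequency of the grid, so the mask passes every sampled frequency. At one metre the limit drops below Nyquist, and the mask starts removing the highest frequencies.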

get_impulse_response_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), scale=1, aperture_samples=[20, 20, 5, 5])

Helper function for odak.learn.wave.impulse_response_fresnel.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • scale
                 Scale with respect to nu and nv (e.g., scale = 2 leads to 2 x nu and 2 x nv resolution for H).
    
  • aperture_samples
                 Number of samples to represent a rectangular pixel. The first two are for the X and Y axes of hologram plane pixels, and the last two are for image plane pixels.
    

Returns:

  • H ( complex ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_impulse_response_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu'), scale = 1, aperture_samples = [20, 20, 5, 5]):
    """
    Helper function for odak.learn.wave.impulse_response_fresnel.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    scale              : int
                         Scale with respect to nu and nv (e.g., scale = 2 leads to 2 x nu and 2 x nv resolution for H).
    aperture_samples   : list
                         Number of samples to represent a rectangular pixel. The first two are for the X and Y axes of hologram plane pixels, and the last two are for image plane pixels.

    Returns
    -------
    H                  : torch.complex
                         Complex kernel in Fourier domain.
    """
    k = wavenumber(wavelength)
    distance = torch.as_tensor(distance, device = device)
    length_x, length_y = (torch.tensor(dx * nu, device = device), torch.tensor(dx * nv, device = device))
    x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)
    y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)
    X, Y = torch.meshgrid(x, y, indexing = 'ij')
    wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device)
    wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device)
    h = torch.zeros(nu * scale, nv * scale, dtype = torch.complex64, device = device)
    pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device)
    pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device)
    for wx in tqdm(wxs):
        for wy in wys:
            for px in pxs:
                for py in pys:
                    r = (X + px - wx) ** 2 + (Y + py - wy) ** 2
                    h += 1. / (1j * wavelength * distance) * torch.exp(1j * k / (2 * distance) * r) 
    H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]
    return H
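
The quantity accumulated in the nested loops is the Fresnel impulse response evaluated at sub-pixel offsets. The sketch below shows the idea at a single point in plain Python. It is simplified relative to the listing: it averages over one set of offsets instead of two, the helper names are hypothetical, and the wavenumber is assumed to be `2 * pi / wavelength`.

```python
import cmath
import math


def linspace(start, stop, count):
    """Evenly spaced samples including both end points (count >= 1)."""
    if count == 1:
        return [start]
    step = (stop - start) / (count - 1)
    return [start + step * index for index in range(count)]


def fresnel_impulse_response(x, y, wavelength, distance):
    """h(x, y) = exp(j k (x^2 + y^2) / (2 z)) / (j lambda z)."""
    k = 2.0 * math.pi / wavelength  # assumed definition of the wavenumber
    phase = k / (2.0 * distance) * (x * x + y * y)
    return cmath.exp(1j * phase) / (1j * wavelength * distance)


def area_averaged_response(x, y, dx, samples, wavelength, distance):
    """Average h over a samples x samples grid of offsets inside one pixel."""
    offsets = linspace(-dx / 2.0, dx / 2.0, samples)
    total = sum(fresnel_impulse_response(x - wx, y - wy, wavelength, distance)
                for wx in offsets for wy in offsets)
    return total / samples ** 2


wavelength, distance, dx = 515e-9, 1e-3, 8e-6
point = fresnel_impulse_response(0.0, 0.0, wavelength, distance)
averaged = area_averaged_response(1e-4, 0.0, dx, 5, wavelength, distance)
```

Every sample has magnitude `1 / (wavelength * distance)`, so averaging over the pixel area can only lower the magnitude where the phase varies quickly across the pixel.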

get_light_kernels(wavelengths, distances, pixel_pitches, resolution=[1080, 1920], resolution_factor=1, samples=[50, 50, 5, 5], propagation_type='Bandlimited Angular Spectrum', kernel_type='spatial', device=torch.device('cpu'))

Utility function to request a tensor filled with light transport kernels according to the given optical configurations.

Parameters:

  • wavelengths
                 A list of wavelengths.
    
  • distances
                 A list of propagation distances.
    
  • pixel_pitches
                 A list of pixel pitches.
    
  • resolution
                 Resolution of the light transport kernel.
    
  • resolution_factor
                 If `Impulse Response Fresnel` propagation is used, this resolution factor could be set larger than one leading to higher resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().
    
  • samples
                 If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().
    
  • propagation_type
                 Propagation type. For more, see odak.learn.wave.propagate_beam().
    
  • kernel_type
                 If set to `spatial`, light transport kernels will be provided in space. But if set to `fourier`, these kernels will be provided in the Fourier domain.
    
  • device
                 Device used for computation (e.g., cpu, cuda).
    

Returns:

  • light_kernels_amplitude ( tensor ) –

    Amplitudes of the light kernels generated [w x d x p x m x n].

  • light_kernels_phase ( tensor ) –

    Phases of the light kernels generated [w x d x p x m x n].

  • light_kernels_complex ( tensor ) –

    Complex light kernels generated [w x d x p x m x n].

  • light_parameters ( tensor ) –

    Parameters of each pixel in light_kernels* [w x d x p x m x n x 5]. The last dimension contains wavelengths, distances, pixel pitches, and X and Y locations, in that order.

Source code in odak/learn/wave/classical.py
def get_light_kernels(
                      wavelengths,
                      distances,
                      pixel_pitches,
                      resolution = [1080, 1920],
                      resolution_factor = 1,
                      samples = [50, 50, 5, 5],
                      propagation_type = 'Bandlimited Angular Spectrum',
                      kernel_type = 'spatial',
                      device = torch.device('cpu')
                     ):
    """
    Utility function to request a tensor filled with light transport kernels according to the given optical configurations.

    Parameters
    ----------
    wavelengths        : list
                         A list of wavelengths.
    distances          : list
                         A list of propagation distances.
    pixel_pitches      : list
                         A list of pixel pitches.
    resolution         : list
                         Resolution of the light transport kernel.
    resolution_factor  : int
                         If `Impulse Response Fresnel` propagation is used, this resolution factor could be set larger than one leading to higher resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().
    samples            : list
                         If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().
    propagation_type   : str
                         Propagation type. For more, see odak.learn.wave.propagate_beam().
    kernel_type        : str
                         If set to `spatial`, light transport kernels will be provided in space. But if set to `fourier`, these kernels will be provided in the Fourier domain.
    device             : torch.device
                         Device used for computation (e.g., cpu, cuda).

    Returns
    -------
    light_kernels_amplitude : torch.tensor
                              Amplitudes of the light kernels generated [w x d x p x m x n].
    light_kernels_phase     : torch.tensor
                              Phases of the light kernels generated [w x d x p x m x n].
    light_kernels_complex   : torch.tensor
                              Complex light kernels generated [w x d x p x m x n].
    light_parameters        : torch.tensor
                              Parameters of each pixel in light_kernels* [w x d x p x m x n x 5]. The last dimension contains wavelengths, distances, pixel pitches, and X and Y locations, in that order.
    """
    if propagation_type != 'Impulse Response Fresnel':
        resolution_factor = 1
    light_kernels_complex = torch.zeros(            
                                        len(wavelengths),
                                        len(distances),
                                        len(pixel_pitches),
                                        resolution[0] * resolution_factor,
                                        resolution[1] * resolution_factor,
                                        dtype = torch.complex64,
                                        device = device
                                       )
    light_parameters = torch.zeros(
                                   len(wavelengths),
                                   len(distances),
                                   len(pixel_pitches),
                                   resolution[0] * resolution_factor,
                                   resolution[1] * resolution_factor,
                                   5,
                                   dtype = torch.float32,
                                   device = device
                                  )
    for wavelength_id, distance_id, pixel_pitch_id in itertools.product(
                                                                        range(len(wavelengths)),
                                                                        range(len(distances)),
                                                                        range(len(pixel_pitches)),
                                                                       ):
        pixel_pitch = pixel_pitches[pixel_pitch_id]
        wavelength = wavelengths[wavelength_id]
        distance = distances[distance_id]
        kernel_fourier = get_propagation_kernel(
                                                nu = resolution[0],
                                                nv = resolution[1],
                                                dx = pixel_pitch,
                                                wavelength = wavelength,
                                                distance = distance,
                                                device = device,
                                                propagation_type = propagation_type,
                                                scale = resolution_factor,
                                                samples = samples
                                               )
        if kernel_type == 'spatial':
            kernel = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(kernel_fourier)))
        elif kernel_type == 'fourier':
            kernel = kernel_fourier
        else:
            logging.warning('Unknown kernel type requested.')
            raise ValueError('Unknown kernel type requested.')
        kernel_amplitude = calculate_amplitude(kernel)
        kernel_phase = calculate_phase(kernel) % (2 * torch.pi)
        light_kernels_complex[wavelength_id, distance_id, pixel_pitch_id] = kernel
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 0] = wavelength
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 1] = distance
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 2] = pixel_pitch
        x = torch.linspace(-1., 1., resolution[0] * resolution_factor, device = device) * pixel_pitch / 2. * resolution[0]
        y = torch.linspace(-1., 1., resolution[1] * resolution_factor, device = device) * pixel_pitch / 2. * resolution[1]
        X, Y = torch.meshgrid(x, y, indexing = 'ij')
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 3] = X
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 4] = Y
    light_kernels_amplitude = calculate_amplitude(light_kernels_complex)
    light_kernels_phase = calculate_phase(light_kernels_complex) % (2. * torch.pi)
    return light_kernels_amplitude, light_kernels_phase, light_kernels_complex, light_parameters
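
The shape of the returned tensors and the order in which configurations are visited can be worked out without allocating anything. The sketch below mirrors the bookkeeping in the listing above using only the standard library; `light_kernel_layout` is a hypothetical helper and the wavelengths, distances and pitch are example values.

```python
import itertools


def light_kernel_layout(wavelengths, distances, pixel_pitches,
                        resolution=(1080, 1920), resolution_factor=1,
                        propagation_type='Bandlimited Angular Spectrum'):
    """Return the [w x d x p x m x n] shape and the order of configurations."""
    # The resolution factor only applies to impulse response kernels.
    if propagation_type != 'Impulse Response Fresnel':
        resolution_factor = 1
    shape = (len(wavelengths), len(distances), len(pixel_pitches),
             resolution[0] * resolution_factor,
             resolution[1] * resolution_factor)
    order = list(itertools.product(wavelengths, distances, pixel_pitches))
    return shape, order


shape, order = light_kernel_layout(
    wavelengths=[639e-9, 515e-9, 473e-9],
    distances=[1e-3, 5e-3],
    pixel_pitches=[8e-6],
    resolution_factor=2,
)
scaled_shape, _ = light_kernel_layout(
    wavelengths=[639e-9, 515e-9, 473e-9],
    distances=[1e-3, 5e-3],
    pixel_pitches=[8e-6],
    resolution_factor=2,
    propagation_type='Impulse Response Fresnel',
)
```

Note that a `resolution_factor` of two is silently reset to one for the default propagation type, so `shape` keeps the native resolution while `scaled_shape` doubles it.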

get_propagation_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), propagation_type='Bandlimited Angular Spectrum', scale=1, samples=[20, 20, 5, 5])

Get propagation kernel for the propagation type.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • propagation_type
                 Propagation type.
                 The options are `Angular Spectrum`, `Bandlimited Angular Spectrum`, `Transfer Function Fresnel` and `Impulse Response Fresnel`.
    
  • scale
                 Scale factor for scaled beam propagation.
    
  • samples
                 When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram pixel and the last two are for an image plane pixel.
    

Returns:

  • kernel ( tensor ) –

    Complex kernel for the given propagation type.

Source code in odak/learn/wave/classical.py
def get_propagation_kernel(
                           nu, 
                           nv, 
                           dx = 8e-6, 
                           wavelength = 515e-9, 
                           distance = 0., 
                           device = torch.device('cpu'), 
                           propagation_type = 'Bandlimited Angular Spectrum', 
                           scale = 1,
                           samples = [20, 20, 5, 5]
                          ):
    """
    Get propagation kernel for the propagation type.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    propagation_type   : str
                         Propagation type.
                         The options are `Angular Spectrum`, `Bandlimited Angular Spectrum`, `Transfer Function Fresnel` and `Impulse Response Fresnel`.
    scale              : int
                         Scale factor for scaled beam propagation.
    samples            : list
                         When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram pixel and the last two are for an image plane pixel.


    Returns
    -------
    kernel             : torch.tensor
                         Complex kernel for the given propagation type.
    """
    if propagation_type == 'Bandlimited Angular Spectrum':
        kernel = get_band_limited_angular_spectrum_kernel(nu, nv, dx, wavelength, distance, device)
    elif propagation_type == 'Angular Spectrum':
        kernel = get_angular_spectrum_kernel(nu, nv, dx, wavelength, distance, device)
    elif propagation_type == 'Transfer Function Fresnel':
        kernel = get_transfer_function_fresnel_kernel(nu, nv, dx, wavelength, distance, device)
    elif propagation_type == 'Impulse Response Fresnel':
        kernel = get_impulse_response_fresnel_kernel(nu, nv, dx, wavelength, distance, device, scale = scale, aperture_samples = samples)
    else:
        logging.warning('Propagation type not recognized')
        assert True == False
    return kernel
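
The dispatch above can also be written as a lookup table. The following is a minimal sketch and not odak's implementation: the builder functions are hypothetical placeholders that only return a label, and raising a `ValueError` for an unknown type is an assumption about a reasonable failure mode.

```python
# Hypothetical stand-ins for the kernel builders; each only returns a label here.
def build_angular_spectrum(nu, nv):
    return ('Angular Spectrum', nu, nv)


def build_transfer_function_fresnel(nu, nv):
    return ('Transfer Function Fresnel', nu, nv)


# Map each propagation type string to the function that builds its kernel.
KERNEL_BUILDERS = {
    'Angular Spectrum': build_angular_spectrum,
    'Transfer Function Fresnel': build_transfer_function_fresnel,
}


def select_kernel(propagation_type, nu, nv):
    """Look up the builder by name and fail loudly for unknown names."""
    try:
        builder = KERNEL_BUILDERS[propagation_type]
    except KeyError:
        raise ValueError('Propagation type not recognized: {}'.format(propagation_type))
    return builder(nu, nv)


kernel = select_kernel('Angular Spectrum', 4, 4)
```

A table like this keeps the list of supported names in one place, which makes it harder for the documented options and the implemented branches to drift apart.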

get_transfer_function_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.transfer_function_fresnel.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
    """
    Helper function for odak.learn.wave.transfer_function_fresnel.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing = 'ij')
    k = wavenumber(wavelength)
    H = torch.exp(-1j * distance * (k - torch.pi * wavelength * (FX ** 2 + FY ** 2)))
    return H
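
To make the formula in the listing concrete, here is a minimal sketch that evaluates the same expression, H = exp(-j z (k - pi lambda (fx^2 + fy^2))), on a small grid using only the Python standard library. The helper names are hypothetical, and the wavenumber is assumed to be 2 pi / wavelength.

```python
import cmath
import math


def linspace(start, stop, count):
    """Evenly spaced samples from start to stop, both ends included."""
    if count == 1:
        return [start]
    step = (stop - start) / (count - 1)
    return [start + i * step for i in range(count)]


def fresnel_transfer_function(nu, nv, dx=8e-6, wavelength=515e-9, distance=0.0):
    """Return an nu x nv nested list holding the Fresnel transfer function."""
    k = 2.0 * math.pi / wavelength  # assumed definition of the wavenumber
    fx = linspace(-0.5 / dx, 0.5 / dx, nu)
    fy = linspace(-0.5 / dx, 0.5 / dx, nv)
    return [
        [cmath.exp(-1j * distance * (k - math.pi * wavelength * (u ** 2 + v ** 2))) for v in fy]
        for u in fx
    ]


H = fresnel_transfer_function(4, 4, distance=1e-3)
H0 = fresnel_transfer_function(4, 4, distance=0.0)
```

Every entry of `H` has unit magnitude, so the kernel only changes the phase of each spatial frequency, and a propagation distance of zero reduces the kernel to ones.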

impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0, scale=1, samples=[20, 20, 5, 5])

A definition to calculate convolution based Fresnel approximation for beam propagation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    
  • scale
               Resolution factor to scale generated kernel.
    
  • samples
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]):
    """
    A definition to calculate convolution based Fresnel approximation for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].
    scale            : int
                       Resolution factor to scale generated kernel.
    samples          : list
                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Impulse Response Fresnel',
                               device = field.device,
                               scale = scale,
                               samples = samples
                              )
    if scale > 1:
        field_amplitude = calculate_amplitude(field)
        field_phase = calculate_phase(field)
        field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)
        field_scale_phase = torch.zeros_like(field_scale_amplitude)
        field_scale_amplitude[::scale, ::scale] = field_amplitude
        field_scale_phase[::scale, ::scale] = field_phase
        field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)
    else:
        field_scale = field
    result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)
    return result
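
When `scale` is larger than one, the listing spreads the input samples over a finer grid and leaves zeros in between. A minimal sketch of that zero-insertion step, using nested lists instead of tensors (the function name is hypothetical):

```python
def upsample_with_zeros(field, scale):
    """Place each sample of a 2D nested list at every `scale`-th position of a zero-filled grid."""
    rows, cols = len(field), len(field[0])
    upsampled = [[0j] * (cols * scale) for _ in range(rows * scale)]
    for i in range(rows):
        for j in range(cols):
            upsampled[i * scale][j * scale] = field[i][j]
    return upsampled


field = [[1 + 0j, 2 + 0j],
         [3 + 0j, 4 + 0j]]
field_scale = upsample_with_zeros(field, scale=2)
```

The listing fills amplitude and phase separately before recombining them. Writing the complex samples directly gives the same values, because a zero amplitude yields a zero complex sample regardless of phase.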

point_wise(target, wavelength, distance, dx, device, lens_size=401)

Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

Parameters:

  • target
               Float input target to be converted into a hologram (values should be in the range of 0 to 1).
    
  • wavelength
               Wavelength of the electric field.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • device
               Device type (cuda or cpu).
    
  • lens_size
               Size of lens for masking sub-holograms (in pixels).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

Source code in odak/learn/wave/classical.py
def point_wise(target, wavelength, distance, dx, device, lens_size=401):
    """
    Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

    Parameters
    ----------
    target           : torch.float
                       Float input target to be converted into a hologram (values should be in the range of 0 to 1).
    wavelength       : float
                       Wavelength of the electric field.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    device           : torch.device
                       Device type (cuda or cpu).
    lens_size        : int
                       Size of lens for masking sub-holograms (in pixels).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    """
    target = zero_pad(target)
    nx, ny = target.shape
    k = wavenumber(wavelength)
    ones = torch.ones(target.shape, requires_grad=False).to(device)
    x = torch.linspace(-nx/2, nx/2, nx).to(device)
    y = torch.linspace(-ny/2, ny/2, ny).to(device)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = (X**2+Y**2)**0.5
    mask = (torch.abs(Z) <= lens_size)
    mask[mask > 1] = 1
    fz = quadratic_phase_function(nx, ny, k, focal=-distance, dx=dx).to(device)
    A = torch.nan_to_num(target**0.5, nan=0.0)
    fz = mask*fz
    FA = torch.fft.fft2(torch.fft.fftshift(A))
    FFZ = torch.fft.fft2(torch.fft.fftshift(fz))
    H = torch.mul(FA, FFZ)
    hologram = torch.fft.ifftshift(torch.fft.ifft2(H))
    hologram = crop_center(hologram)
    return hologram

propagate_beam(field, k, distance, dx, wavelength, propagation_type='Bandlimited Angular Spectrum', kernel=None, zero_padding=[True, False, True], aperture=1.0, scale=1, samples=[20, 20, 5, 5])

Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • propagation_type (str, default: 'Bandlimited Angular Spectrum' ) –
               Type of the propagation.
               The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer and custom (which uses the provided `kernel`).
    
  • kernel
               Custom complex kernel.
    
  • zero_padding
               Zero pads the input field if the first item in the list is set to True.
               Zero pads in the Fourier domain if the second item in the list is set to True.
               Crops the result to half resolution if the third item in the list is set to True.
               Note that in Fraunhofer propagation, setting the second item to True or False has no effect.
    
  • aperture
               Aperture in the Fourier domain. The default size is [2m x 2n]; otherwise the size depends on `zero_padding`.
               If provided as the floating point value 1, no aperture is applied in the Fourier domain.
    
  • scale
               Resolution factor to scale generated kernel.
    
  • samples
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram pixel and the last two are for an image plane pixel.
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def propagate_beam(
                   field,
                   k,
                   distance,
                   dx,
                   wavelength,
                   propagation_type='Bandlimited Angular Spectrum',
                   kernel = None,
                   zero_padding = [True, False, True],
                   aperture = 1.,
                   scale = 1,
                   samples = [20, 20, 5, 5]
                  ):
    """
    Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    propagation_type : str
                       Type of the propagation.
                       The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer and custom (which uses the provided `kernel`).
    kernel           : torch.complex
                       Custom complex kernel.
    zero_padding     : list
                       Zero pads the input field if the first item in the list is set to True.
                       Zero pads in the Fourier domain if the second item in the list is set to True.
                       Crops the result to half resolution if the third item in the list is set to True.
                       Note that in Fraunhofer propagation, setting the second item to True or False has no effect.
    aperture         : torch.tensor
                       Aperture in the Fourier domain. The default size is [2m x 2n]; otherwise the size depends on `zero_padding`.
                       If provided as the floating point value 1, no aperture is applied in the Fourier domain.
    scale            : int
                       Resolution factor to scale generated kernel.
    samples          : list
                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram pixel and the last two are for an image plane pixel.

    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    if zero_padding[0]:
        field = zero_pad(field)
    if propagation_type == 'Angular Spectrum':
        result = angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
    elif propagation_type == 'Bandlimited Angular Spectrum':
        result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
    elif propagation_type == 'Impulse Response Fresnel':
        result = impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
    elif propagation_type == 'Transfer Function Fresnel':
        result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
    elif propagation_type == 'custom':
        result = custom(field, kernel, zero_padding[1], aperture = aperture)
    elif propagation_type == 'Fraunhofer':
        result = fraunhofer(field, k, distance, dx, wavelength)
    else:
        logging.warning('Propagation type not recognized')
        assert True == False
    if zero_padding[2]:
        result = crop_center(result)
    return result
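
The first and third entries of `zero_padding` wrap the propagation in a pad-then-crop pair. The sketch below illustrates that round trip under an assumption: the padding is taken to embed an m x n field at the centre of a 2m x 2n grid of zeros, and the crop is taken to keep the central half. The helper names are hypothetical and the code handles even sizes only; odak's own `zero_pad` and `crop_center` may differ in detail.

```python
def zero_pad_sketch(field):
    """Embed an m x n nested list at the centre of a 2m x 2n grid of zeros (m, n even)."""
    m, n = len(field), len(field[0])
    padded = [[0j] * (2 * n) for _ in range(2 * m)]
    for i in range(m):
        for j in range(n):
            padded[i + m // 2][j + n // 2] = field[i][j]
    return padded


def crop_center_sketch(field):
    """Keep the central half of the grid along both axes."""
    m, n = len(field) // 2, len(field[0]) // 2
    return [row[n // 2:n // 2 + n] for row in field[m // 2:m // 2 + m]]


field = [[1 + 0j, 2 + 0j],
         [3 + 0j, 4 + 0j]]
padded = zero_pad_sketch(field)
restored = crop_center_sketch(padded)
```

Under these assumptions, cropping undoes the padding exactly, so the result has the same [m x n] size as the input, as stated in the Returns section.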

shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None)

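Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. The implementation follows the tensor_holography reference code (https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207) and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.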

Parameters:

  • phase
               Phase value of a phase-only hologram.
    
  • depth_shift
               Distance in meters.
    
  • pixel_pitch
               Pixel pitch size in meters.
    
  • wavelength
               Wavelength of light.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Beam propagation type. For more see odak.learn.wave.propagate_beam().
    
  • kernel_length
               Kernel length for the Gaussian blur kernel.
    
  • sigma
               Standard deviation for the Gaussian blur kernel.
    
  • amplitude
               Amplitude value of a complex hologram.
    
Source code in odak/learn/wave/classical.py
def shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None):
    """
    Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. The implementation follows [this reference code](https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207) and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

    Parameters
    ----------
    phase            : torch.tensor
                       Phase value of a phase-only hologram.
    depth_shift      : float
                       Distance in meters.
    pixel_pitch      : float
                       Pixel pitch size in meters.
    wavelength       : float
                       Wavelength of light.
    propagation_type : str
                       Beam propagation type. For more see odak.learn.wave.propagate_beam().
    kernel_length    : int
                       Kernel length for the Gaussian blur kernel.
    sigma            : float
                       Standard deviation for the Gaussian blur kernel.
    amplitude        : torch.tensor
                       Amplitude value of a complex hologram.
    """
    if type(amplitude) == type(None):
        amplitude = torch.ones_like(phase)
    hologram = generate_complex_field(amplitude, phase)
    k = wavenumber(wavelength)
    hologram_padded = zero_pad(hologram)
    shifted_field_padded = propagate_beam(
                                          hologram_padded,
                                          k,
                                          depth_shift,
                                          pixel_pitch,
                                          wavelength,
                                          propagation_type
                                         )
    shifted_field = crop_center(shifted_field_padded)
    phase_shift = torch.exp(torch.tensor([-2 * torch.pi * depth_shift / wavelength]).to(phase.device))
    shift = torch.cos(phase_shift) + 1j * torch.sin(phase_shift)
    shifted_complex_hologram = shifted_field * shift

    if kernel_length > 0 and sigma >0:
        blur_kernel = generate_2d_gaussian(
                                           [kernel_length, kernel_length],
                                           [sigma, sigma]
                                          ).to(phase.device)
        blur_kernel = blur_kernel.unsqueeze(0)
        blur_kernel = blur_kernel.unsqueeze(0)
        field_imag = torch.imag(shifted_complex_hologram)
        field_real = torch.real(shifted_complex_hologram)
        field_imag = field_imag.unsqueeze(0)
        field_imag = field_imag.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_imag = torch.nn.functional.conv2d(field_imag, blur_kernel, padding='same')
        field_real = torch.nn.functional.conv2d(field_real, blur_kernel, padding='same')
        shifted_complex_hologram = torch.complex(field_real, field_imag)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)

    shifted_amplitude = calculate_amplitude(shifted_complex_hologram)
    shifted_amplitude = shifted_amplitude / torch.amax(shifted_amplitude, [0,1])

    shifted_phase = calculate_phase(shifted_complex_hologram)
    phase_zero_mean = shifted_phase - torch.mean(shifted_phase)

    phase_offset = torch.arccos(shifted_amplitude)
    phase_low = phase_zero_mean - phase_offset
    phase_high = phase_zero_mean + phase_offset

    phase_only = torch.zeros_like(phase)
    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
    return phase_only
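
The last part of the listing relies on the double phase identity: a complex value with amplitude a in [0, 1] and phase p equals the average of two unit-amplitude phasors with phases p - arccos(a) and p + arccos(a). A minimal standard-library check of that identity for a single sample (the function name is hypothetical):

```python
import cmath
import math


def double_phase(amplitude, phase):
    """Split a complex value with amplitude in [0, 1] into two phase-only values."""
    offset = math.acos(amplitude)
    return phase - offset, phase + offset


amplitude, phase = 0.6, 0.25
phase_low, phase_high = double_phase(amplitude, phase)
# Averaging the two unit-amplitude phasors recovers the original complex value.
recovered = 0.5 * (cmath.exp(1j * phase_low) + cmath.exp(1j * phase_high))
original = amplitude * cmath.exp(1j * phase)
```

This is why the listing normalises the amplitude by its maximum before calling `arccos`, and why the low and high phases are interleaved in a checkerboard: neighbouring pixels then average to the intended complex value.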

stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type='Bandlimited Angular Spectrum', n_iteration=100, loss_function=None, learning_rate=0.1)

Definition to generate phase and reconstruction from target image via stochastic gradient descent.

Parameters:

  • target
                        Target field amplitude [m x n].
                        Keep the target values between zero and one.
    
  • wavelength
                        Wavelength of the electric field.
    
  • distance
                        Hologram plane distance with respect to the SLM plane.
    
  • pixel_pitch
                        SLM pixel pitch in meters.
    
  • propagation_type
                        Type of the propagation (see odak.learn.wave.propagate_beam()).
    
  • n_iteration
                        Number of iterations.
    
  • loss_function
                        If None, an L2 (mean squared error) loss is used.
    
  • learning_rate
                        Learning rate.
    

Returns:

  • hologram ( Tensor ) –

    Phase-only hologram as a torch tensor.

  • reconstruction_intensity ( Tensor ) –

    Reconstruction as a torch tensor.

Source code in odak/learn/wave/classical.py
def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type = 'Bandlimited Angular Spectrum', n_iteration = 100, loss_function = None, learning_rate = 0.1):
    """
    Definition to generate phase and reconstruction from target image via stochastic gradient descent.

    Parameters
    ----------
    target                    : torch.Tensor
                                Target field amplitude [m x n].
                                Keep the target values between zero and one.
    wavelength                : double
                                Wavelength of the electric field.
    distance                  : double
                                Hologram plane distance with respect to the SLM plane.
    pixel_pitch               : float
                                SLM pixel pitch in meters.
    propagation_type          : str
                                Type of the propagation (see odak.learn.wave.propagate_beam()).
    n_iteration               : int
                                Number of iterations.
    loss_function             : function
                                If None, an L2 (mean squared error) loss is used.
    learning_rate             : float
                                Learning rate.

    Returns
    -------
    hologram                  : torch.Tensor
                                Phase-only hologram as a torch tensor.

    reconstruction_intensity  : torch.Tensor
                                Reconstruction as a torch tensor.

    """
    phase = torch.randn_like(target, requires_grad = True)
    k = wavenumber(wavelength)
    optimizer = torch.optim.Adam([phase], lr = learning_rate)
    if type(loss_function) == type(None):
        loss_function = torch.nn.MSELoss()
    t = tqdm(range(n_iteration), leave = False, dynamic_ncols = True)
    for i in t:
        optimizer.zero_grad()
        hologram = generate_complex_field(1., phase)
        reconstruction = propagate_beam(
                                        hologram, 
                                        k, 
                                        distance, 
                                        pixel_pitch, 
                                        wavelength, 
                                        propagation_type, 
                                        zero_padding = [True, False, True]
                                       )
        reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
        loss = loss_function(reconstruction_intensity, target)
        description = "Loss:{:.4f}".format(loss.item())
        loss.backward(retain_graph = True)
        optimizer.step()
        t.set_description(description)
    print(description)
    # A bare torch.no_grad() call has no effect; use it as a context manager instead.
    with torch.no_grad():
        hologram = generate_complex_field(1., phase)
        reconstruction = propagate_beam(
                                        hologram, 
                                        k, 
                                        distance, 
                                        pixel_pitch, 
                                        wavelength, 
                                        propagation_type, 
                                        zero_padding = [True, False, True]
                                       )
    return hologram, reconstruction

transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution based Fresnel approximation for beam propagation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
    """
    A definition to calculate convolution based Fresnel approximation for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Transfer Function Fresnel',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result

blazed_grating(nx, ny, levels=2, axis='x')

A definition to generate a blazed grating (also known as ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • levels
           Number of pixels in one grating period (values below two are set to two).
    
  • axis
           Axis of the blazed grating. It could be `x` or `y`.
    
Source code in odak/learn/wave/lens.py
def blazed_grating(nx, ny, levels = 2, axis = 'x'):
    """
    A definition to generate a blazed grating (also known as ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.


    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    levels       : int
                   Number of pixels in one grating period (values below two are set to two).
    axis         : str
                   Axis of the blazed grating. It could be `x` or `y`.

    """
    if levels < 2:
        levels = 2
    x = (torch.abs(torch.arange(-nx, 0)) % levels) / levels * (2 * np.pi)
    y = (torch.abs(torch.arange(-ny, 0)) % levels) / levels * (2 * np.pi)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'x':
        blazed_grating = torch.exp(1j * X)
    elif axis == 'y':
        blazed_grating = torch.exp(1j * Y)
    return blazed_grating
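
To see what `levels` controls, the sketch below reproduces the one-dimensional phase profile computed in the listing, using plain Python instead of torch. The function name is hypothetical.

```python
import math


def blazed_phase_profile(n, levels=2):
    """Return the 1D phase profile (radians) that the listing builds along one axis."""
    levels = max(levels, 2)  # the listing clamps levels to at least two
    return [(abs(i) % levels) / levels * (2 * math.pi) for i in range(-n, 0)]


profile = blazed_phase_profile(6, levels=3)
```

With `levels=3` the profile repeats every three pixels and takes three distinct phase values between 0 and 2 pi, so `levels` acts as the grating period in pixels.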

linear_grating(nx, ny, every=2, add=None, axis='x')

A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • every
         Apply the `add` value at every given number of pixels.
    
  • add
         Angle to be added.
    
  • axis
         Axis, either `x`, `y` or `xy` (both).
    

Returns:

  • field ( tensor ) –

    Linear grating term.

Source code in odak/learn/wave/lens.py
def linear_grating(nx, ny, every = 2, add = None, axis = 'x'):
    """
    A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    every      : int
                 Apply the `add` value at every given number of pixels.
    add        : float
                 Angle to be added.
    axis       : string
                 Axis, either `x`, `y` or `xy` (both).

    Returns
    ----------
    field      : torch.tensor
                 Linear grating term.
    """
    if isinstance(add, type(None)):
        add = np.pi
    grating = torch.zeros((nx, ny), dtype=torch.complex64)
    if axis == 'x':
        grating[::every, :] = torch.exp(torch.tensor(1j*add))
    if axis == 'y':
        grating[:, ::every] = torch.exp(torch.tensor(1j*add))
    if axis == 'xy':
        checker = np.indices((nx, ny)).sum(axis=0) % every
        checker = torch.from_numpy(checker)
        checker += 1
        checker = checker % 2
        grating = torch.exp(1j*checker*add)
    return grating
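
The `xy` branch of the listing builds a checkerboard. A minimal standard-library sketch of that branch follows; the function name is hypothetical, and `add` defaults to pi as in the listing.

```python
import cmath
import math


def checkerboard_grating(nx, ny, every=2, add=math.pi):
    """Mirror the 'xy' branch: phase `add` where ((i + j) % every + 1) % 2 equals one."""
    grating = []
    for i in range(nx):
        row = []
        for j in range(ny):
            checker = ((i + j) % every + 1) % 2
            row.append(cmath.exp(1j * checker * add))
        grating.append(row)
    return grating


grating = checkerboard_grating(2, 2)
```

With the defaults, pixels where `i + j` is even receive a phase of pi and the remaining pixels receive zero phase, so every pixel has unit amplitude. In the `x` and `y` branches of the listing the grating starts from zeros, so the pixels that are not selected keep a value of zero rather than unit amplitude.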

prism_grating(nx, ny, k, angle, dx=0.001, axis='x', phase_offset=0.0)

A definition to generate a 2D phase function that represents a prism. For more, see Goodman's Introduction to Fourier Optics or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics Express 16.22 (2008): 18275-18287.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • k
           See odak.wave.wavenumber for more.
    
  • angle
           Tilt angle of the prism in degrees.
    
  • dx
           Pixel pitch.
    
  • axis
           Axis of the prism, either 'x' or 'y'.
    
  • phase_offset (float, default: 0.0 ) –
           Phase offset in degrees. Default is zero.
    

Returns:

  • prism ( tensor ) –

    Generated phase function for a prism.

Source code in odak/learn/wave/lens.py
def prism_grating(nx, ny, k, angle, dx = 0.001, axis = 'x', phase_offset = 0.):
    """
    A definition to generate 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics express 16.22 (2008): 18275-18287. for more.

    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    k            : odak.wave.wavenumber
                   See odak.wave.wavenumber for more.
    angle        : float
                   Tilt angle of the prism in degrees.
    dx           : float
                   Pixel pitch.
    axis         : str
                   Axis of the prism, either 'x' or 'y'.
    phase_offset : float
                   Phase offset in degrees. Default is zero.

    Returns
    ----------
    prism        : torch.tensor
                   Generated phase function for a prism.
    """
    angle = torch.deg2rad(torch.tensor([angle]))
    phase_offset = torch.deg2rad(torch.tensor([phase_offset]))
    x = torch.arange(0, nx) * dx
    y = torch.arange(0, ny) * dx
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'y':
        phase = k * torch.sin(angle) * Y + phase_offset
        prism = torch.exp(-1j * phase)
    elif axis == 'x':
        phase = k * torch.sin(angle) * X + phase_offset
        prism = torch.exp(-1j * phase)
    return prism
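
As a rough illustration of the linear phase ramp computed above, here is a one-dimensional plain-Python sketch. The helper `prism_phase_1d` is hypothetical, and it assumes the wavenumber is `2 * pi / wavelength`; the example wavelength and pixel pitch are arbitrary values chosen for illustration.

```python
import cmath
import math


def prism_phase_1d(n, wavelength, angle_degrees, dx=0.001, phase_offset_degrees=0.0):
    """Hypothetical 1D analogue of the axis='x' branch of prism_grating."""
    k = 2.0 * math.pi / wavelength  # assumed wavenumber definition
    angle = math.radians(angle_degrees)  # mirrors torch.deg2rad(angle)
    offset = math.radians(phase_offset_degrees)
    samples = []
    for i in range(n):
        x = i * dx  # mirrors torch.arange(0, nx) * dx
        phase = k * math.sin(angle) * x + offset
        samples.append(cmath.exp(-1j * phase))
    return samples


samples = prism_phase_1d(n=4, wavelength=515e-9, angle_degrees=0.01, dx=8e-6)
# Phase increment between neighbouring pixels for the same settings.
phase_step = 2.0 * math.pi / 515e-9 * math.sin(math.radians(0.01)) * 8e-6
print(phase_step, [cmath.phase(sample) for sample in samples])
```

Every sample has unit magnitude, and the phase decreases by the same step from one pixel to the next, which is the behaviour of a thin prism in this model.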

quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • k
         See odak.wave.wavenumber for more.
    
  • focal
         Focal length of the quadratic phase function.
    
  • dx
         Pixel pitch.
    
  • offset
         Deviation from the center along X and Y axes.
    

Returns:

  • function ( tensor ) –

    Generated quadratic phase function.

Source code in odak/learn/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):
    """ 
    A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    k          : odak.wave.wavenumber
                 See odak.wave.wavenumber for more.
    focal      : float
                 Focal length of the quadratic phase function.
    dx         : float
                 Pixel pitch.
    offset     : list
                 Deviation from the center along X and Y axes.

    Returns
    -------
    function   : torch.tensor
                 Generated quadratic phase function.
    """
    size = [nx, ny]
    x = torch.linspace(-size[0] * dx / 2, size[0] * dx / 2, size[0]) - offset[1] * dx
    y = torch.linspace(-size[1] * dx / 2, size[1] * dx / 2, size[1]) - offset[0] * dx
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = X**2 + Y**2
    qwf = torch.exp(-0.5j * k / focal * Z)
    return qwf
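
The quadratic term above, exp(-0.5j * k / focal * (X^2 + Y^2)), can be checked along a single axis with a small plain-Python sketch. The helper `quadratic_phase_1d` is hypothetical, and it assumes the wavenumber is `2 * pi / wavelength`; the wavelength is an arbitrary example value.

```python
import cmath
import math


def quadratic_phase_1d(n, wavelength, focal=0.4, dx=0.001):
    """Hypothetical 1D analogue of quadratic_phase_function with zero offset."""
    k = 2.0 * math.pi / wavelength  # assumed wavenumber definition
    half = n * dx / 2.0
    step = 2.0 * half / (n - 1)
    # Mirrors torch.linspace(-half, half, n); requires n > 1.
    xs = [-half + i * step for i in range(n)]
    return [cmath.exp(-0.5j * k / focal * x ** 2) for x in xs]


profile = quadratic_phase_1d(n=5, wavelength=515e-9)
print([round(cmath.phase(value), 4) for value in profile])
```

The profile is symmetric about the centre sample, where the phase is zero, and every sample has unit magnitude, as expected for a phase-only lens term.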

multiplane_loss

Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.

Source code in odak/learn/wave/loss.py
class multiplane_loss():
    """
    Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.
    """

    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
                 target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], 
                 multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
        """
        Parameters
        ----------
        target_image      : torch.tensor
                            Color target image [3 x m x n].
        target_depth      : torch.tensor
                            Monochrome target depth, same resolution as target_image.
        target_blur_size  : int
                            Maximum target blur size.
        blur_ratio        : float
                            Blur ratio, a value between zero and one.
        number_of_planes  : int
                            Number of planes.
        weights           : list
                            Weights of the loss function.
        multiplier        : float
                            Multiplier to multiply with targets.
        scheme            : str
                            The type of the loss, `naive` without defocus or `defocus` with defocus.
        reduction         : str
                            Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
        device            : torch.device
                            Device to be used (e.g., cuda, cpu, opencl).
        """
        self.device = device
        self.target_image     = target_image.float().to(self.device)
        self.target_depth     = target_depth.float().to(self.device)
        self.target_blur_size = target_blur_size
        if self.target_blur_size % 2 == 0:
            self.target_blur_size += 1
        self.number_of_planes = number_of_planes
        self.multiplier       = multiplier
        self.weights          = weights
        self.reduction        = reduction
        self.blur_ratio       = blur_ratio
        self.set_targets()
        if scheme == 'defocus':
            self.add_defocus_blur()
        self.loss_function = torch.nn.MSELoss(reduction = self.reduction)

    def get_targets(self):
        """
        Returns
        -------
        targets           : torch.tensor
                            Returns a copy of the targets.
        focus_target      : torch.tensor
                            Returns a copy of the focus target assembled in set_targets().
        target_depth      : torch.tensor
                            Returns a copy of the normalized quantized depth map.

        """
        divider = self.number_of_planes - 1
        if divider == 0:
            divider = 1
        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider


    def set_targets(self):
        """
        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
        """
        self.target_depth = self.target_depth * (self.number_of_planes - 1)
        self.target_depth = torch.round(self.target_depth, decimals = 0)
        self.targets      = torch.zeros(
                                        self.number_of_planes,
                                        self.target_image.shape[0],
                                        self.target_image.shape[1],
                                        self.target_image.shape[2],
                                        requires_grad = False,
                                        device = self.device
                                       )
        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
        self.masks        = torch.zeros_like(self.targets)
        for i in range(self.number_of_planes):
            for ch in range(self.target_image.shape[0]):
                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
                new_target = self.target_image[ch] * mask
                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
                self.masks[i, ch] = mask.detach().clone() 


    def add_defocus_blur(self):
        """
        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
        """
        kernel_length = [self.target_blur_size, self.target_blur_size ]
        for ch in range(self.target_image.shape[0]):
            targets_cache = self.targets[:, ch].detach().clone()
            target = torch.sum(targets_cache, axis = 0)
            for i in range(self.number_of_planes):
                sigmas = torch.linspace(start = 0, end = self.target_blur_size, steps = self.number_of_planes)
                sigmas = sigmas - i * self.target_blur_size / (self.number_of_planes - 1 + 1e-10)
                defocus = torch.zeros_like(targets_cache[i])
                for j in range(self.number_of_planes):
                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                    if torch.sum(targets_cache[j]) > 0:
                        if i == j:
                            nsigma = [0., 0.]
                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                        kernel = kernel / torch.sum(kernel)
                        kernel = kernel.unsqueeze(0).unsqueeze(0)
                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
                self.targets[i, ch] = defocus
        self.targets = self.targets.detach().clone() * self.multiplier


    def __call__(self, image, target, plane_id = None):
        """
        Calculates the multiplane loss against a given target.

        Parameters
        ----------
        image         : torch.tensor
                        Image to compare with a target [3 x m x n].
        target        : torch.tensor
                        Target image for comparison [3 x m x n].
        plane_id      : int
                        Number of the plane under test.

        Returns
        -------
        loss          : torch.tensor
                        Computed loss.
        """
        l2 = self.weights[0] * self.loss_function(image, target)
        if isinstance(plane_id, type(None)):
            mask = self.masks
        else:
            mask= self.masks[plane_id, :]
        l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
        l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
        loss = l2 + l2_mask + l2_cor
        return loss

__call__(image, target, plane_id=None)

Calculates the multiplane loss against a given target.

Parameters:

  • image
            Image to compare with a target [3 x m x n].
    
  • target
            Target image for comparison [3 x m x n].
    
  • plane_id
            Number of the plane under test.
    

Returns:

  • loss ( tensor ) –

    Computed loss.

Source code in odak/learn/wave/loss.py
def __call__(self, image, target, plane_id = None):
    """
    Calculates the multiplane loss against a given target.

    Parameters
    ----------
    image         : torch.tensor
                    Image to compare with a target [3 x m x n].
    target        : torch.tensor
                    Target image for comparison [3 x m x n].
    plane_id      : int
                    Number of the plane under test.

    Returns
    -------
    loss          : torch.tensor
                    Computed loss.
    """
    l2 = self.weights[0] * self.loss_function(image, target)
    if isinstance(plane_id, type(None)):
        mask = self.masks
    else:
        mask= self.masks[plane_id, :]
    l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
    l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
    loss = l2 + l2_mask + l2_cor
    return loss
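
The three weighted terms above can be reproduced on flat lists with a plain-Python sketch. The helpers `mse` and `multiplane_loss_sketch` are hypothetical names; the sketch assumes `reduction = 'mean'` and uses the default weights from the constructor.

```python
def mse(a, b):
    """Mean squared error over two equally long sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def multiplane_loss_sketch(image, target, mask, weights=(1.0, 2.1, 0.6)):
    """Hypothetical plain-Python analogue of multiplane_loss.__call__."""
    # Plain L2 term between image and target.
    l2 = weights[0] * mse(image, target)
    # L2 term restricted to the pixels selected by the mask.
    masked_image = [i * m for i, m in zip(image, mask)]
    masked_target = [t * m for t, m in zip(target, mask)]
    l2_mask = weights[1] * mse(masked_image, masked_target)
    # L2 term between image * target and target * target.
    image_times_target = [i * t for i, t in zip(image, target)]
    target_squared = [t * t for t in target]
    l2_cor = weights[2] * mse(image_times_target, target_squared)
    return l2 + l2_mask + l2_cor


image = [0.5, 0.5, 0.0, 1.0]
target = [1.0, 0.5, 0.0, 1.0]
mask = [1, 1, 0, 0]
loss = multiplane_loss_sketch(image, target, mask)
print(loss)
```

Only the first pixel differs here, and it lies inside the mask, so all three terms contribute; an image identical to the target gives a loss of zero.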

__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, weights=[1.0, 2.1, 0.6], multiplier=1.0, scheme='defocus', reduction='mean', device=torch.device('cpu'))

Parameters:

  • target_image
                Color target image [3 x m x n].
    
  • target_depth
                Monochrome target depth, same resolution as target_image.
    
  • target_blur_size
                Maximum target blur size.
    
  • blur_ratio
                Blur ratio, a value between zero and one.
    
  • number_of_planes
                Number of planes.
    
  • weights
                Weights of the loss function.
    
  • multiplier
                Multiplier to multiply with targets.
    
  • scheme
                The type of the loss, `naive` without defocus or `defocus` with defocus.
    
  • reduction
                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    
  • device
                Device to be used (e.g., cuda, cpu, opencl).
    
Source code in odak/learn/wave/loss.py
def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
             target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], 
             multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
    """
    Parameters
    ----------
    target_image      : torch.tensor
                        Color target image [3 x m x n].
    target_depth      : torch.tensor
                        Monochrome target depth, same resolution as target_image.
    target_blur_size  : int
                        Maximum target blur size.
    blur_ratio        : float
                        Blur ratio, a value between zero and one.
    number_of_planes  : int
                        Number of planes.
    weights           : list
                        Weights of the loss function.
    multiplier        : float
                        Multiplier to multiply with targets.
    scheme            : str
                        The type of the loss, `naive` without defocus or `defocus` with defocus.
    reduction         : str
                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    device            : torch.device
                        Device to be used (e.g., cuda, cpu, opencl).
    """
    self.device = device
    self.target_image     = target_image.float().to(self.device)
    self.target_depth     = target_depth.float().to(self.device)
    self.target_blur_size = target_blur_size
    if self.target_blur_size % 2 == 0:
        self.target_blur_size += 1
    self.number_of_planes = number_of_planes
    self.multiplier       = multiplier
    self.weights          = weights
    self.reduction        = reduction
    self.blur_ratio       = blur_ratio
    self.set_targets()
    if scheme == 'defocus':
        self.add_defocus_blur()
    self.loss_function = torch.nn.MSELoss(reduction = self.reduction)

add_defocus_blur()

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def add_defocus_blur(self):
    """
    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
    """
    kernel_length = [self.target_blur_size, self.target_blur_size ]
    for ch in range(self.target_image.shape[0]):
        targets_cache = self.targets[:, ch].detach().clone()
        target = torch.sum(targets_cache, axis = 0)
        for i in range(self.number_of_planes):
            sigmas = torch.linspace(start = 0, end = self.target_blur_size, steps = self.number_of_planes)
            sigmas = sigmas - i * self.target_blur_size / (self.number_of_planes - 1 + 1e-10)
            defocus = torch.zeros_like(targets_cache[i])
            for j in range(self.number_of_planes):
                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                if torch.sum(targets_cache[j]) > 0:
                    if i == j:
                        nsigma = [0., 0.]
                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                    kernel = kernel / torch.sum(kernel)
                    kernel = kernel.unsqueeze(0).unsqueeze(0)
                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
            self.targets[i, ch] = defocus
    self.targets = self.targets.detach().clone() * self.multiplier
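
The blur width used for each pair of planes comes from `int(abs(i - j) * self.blur_ratio)`. The plain-Python sketch below tabulates that expression only; `blur_sigma_table` is a hypothetical helper, and how `generate_2d_gaussian` treats the resulting values is not covered here.

```python
def blur_sigma_table(number_of_planes=4, blur_ratio=0.25):
    """Hypothetical helper: integer sigma for each (focus plane i, source plane j)."""
    table = []
    for i in range(number_of_planes):
        row = []
        for j in range(number_of_planes):
            # Same expression as nsigma in add_defocus_blur.
            row.append(int(abs(i - j) * blur_ratio))
        table.append(row)
    return table


default_table = blur_sigma_table()
wider_table = blur_sigma_table(number_of_planes=4, blur_ratio=1.0)
for row in wider_table:
    print(row)
```

With the default `blur_ratio = 0.25` and four planes, the largest product is 3 * 0.25 = 0.75, so the integer truncation yields zero for every pair. A ratio of 1.0 gives a sigma equal to the plane distance.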

get_targets()

Returns:

  • targets ( tensor ) –

    Returns a copy of the targets.

  • focus_target ( tensor ) –

    Returns a copy of the focus target assembled in set_targets().

  • target_depth ( tensor ) –

    Returns a copy of the normalized quantized depth map.

Source code in odak/learn/wave/loss.py
def get_targets(self):
    """
    Returns
    -------
    targets           : torch.tensor
                        Returns a copy of the targets.
    focus_target      : torch.tensor
                        Returns a copy of the focus target assembled in set_targets().
    target_depth      : torch.tensor
                        Returns a copy of the normalized quantized depth map.

    """
    divider = self.number_of_planes - 1
    if divider == 0:
        divider = 1
    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider

set_targets()

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):
    """
    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
    """
    self.target_depth = self.target_depth * (self.number_of_planes - 1)
    self.target_depth = torch.round(self.target_depth, decimals = 0)
    self.targets      = torch.zeros(
                                    self.number_of_planes,
                                    self.target_image.shape[0],
                                    self.target_image.shape[1],
                                    self.target_image.shape[2],
                                    requires_grad = False,
                                    device = self.device
                                   )
    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
    self.masks        = torch.zeros_like(self.targets)
    for i in range(self.number_of_planes):
        for ch in range(self.target_image.shape[0]):
            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
            new_target = self.target_image[ch] * mask
            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
            self.masks[i, ch] = mask.detach().clone() 
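
The depth slicing above can be followed on a handful of pixels with a plain-Python sketch. The helpers `quantize_depth` and `slice_into_planes` are hypothetical; the sketch assumes depth values normalised to the range zero to one and a single colour channel.

```python
def quantize_depth(depth_values, number_of_planes=4):
    """Scale normalised depth to plane indices and round, as set_targets does."""
    return [round(d * (number_of_planes - 1)) for d in depth_values]


def slice_into_planes(pixels, depth_values, number_of_planes=4):
    """Hypothetical helper: one masked copy of the pixels per depth plane."""
    quantized = quantize_depth(depth_values, number_of_planes)
    planes = []
    for i in range(number_of_planes):
        mask = [1 if q == i else 0 for q in quantized]
        planes.append([p * m for p, m in zip(pixels, mask)])
    return planes


pixels = [0.1, 0.8, 0.4, 0.9]
depth = [0.0, 0.2, 0.5, 1.0]
quantized = quantize_depth(depth)
planes = slice_into_planes(pixels, depth)
for index, plane in enumerate(planes):
    print(index, plane)
```

Each pixel lands in exactly one plane, so summing the planes pixel by pixel recovers the original values.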

perceptual_multiplane_loss

Perceptual loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.

Source code in odak/learn/wave/loss.py
class perceptual_multiplane_loss():
    """
    Perceptual loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.
    """

    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
                 target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', 
                 base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},
                 additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):
        """
        Parameters
        ----------
        target_image            : torch.tensor
                                    Color target image [3 x m x n].
        target_depth            : torch.tensor
                                    Monochrome target depth, same resolution as target_image.
        target_blur_size        : int
                                    Maximum target blur size.
        blur_ratio              : float
                                    Blur ratio, a value between zero and one.
        number_of_planes        : int
                                    Number of planes.
        multiplier              : float
                                    Multiplier to multiply with targets.
        scheme                  : str
                                    The type of the loss, `naive` without defocus or `defocus` with defocus.
        base_loss_weights       : dict
                                    Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
        additional_loss_weights : dict
                                    Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
        reduction               : str
                                    Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
        return_components       : bool
                                    If True (False by default), returns the components of the loss as a dict.
        device                  : torch.device
                                    Device to be used (e.g., cuda, cpu, opencl).
        """
        self.device = device
        self.target_image     = target_image.float().to(self.device)
        self.target_depth     = target_depth.float().to(self.device)
        self.target_blur_size = target_blur_size
        if self.target_blur_size % 2 == 0:
            self.target_blur_size += 1
        self.number_of_planes = number_of_planes
        self.multiplier       = multiplier
        self.reduction        = reduction
        if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:
            logging.warning("Reduction cannot be 'none' for additional loss functions. Changing reduction to 'mean'.")
            self.reduction = 'mean'
        self.blur_ratio       = blur_ratio
        self.set_targets()
        if scheme == 'defocus':
            self.add_defocus_blur()
        self.base_loss_weights = base_loss_weights
        self.additional_loss_weights = additional_loss_weights
        self.return_components = return_components
        self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)
        self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)
        for key in self.additional_loss_weights.keys():
            if key == 'cvvdp':
                self.cvvdp = CVVDP()
            if key == 'fvvdp':
                self.fvvdp = FVVDP()
            if key == 'lpips':
                self.lpips = LPIPS()
            if key == 'psnr':
                self.psnr = PSNR()
            if key == 'ssim':
                self.ssim = SSIM()
            if key == 'msssim':
                self.msssim = MSSSIM()

    def get_targets(self):
        """
        Returns
        -------
        targets           : torch.tensor
                            Returns a copy of the targets.
        focus_target      : torch.tensor
                            Returns a copy of the focus target assembled in set_targets().
        target_depth      : torch.tensor
                            Returns a copy of the normalized quantized depth map.

        """
        divider = self.number_of_planes - 1
        if divider == 0:
            divider = 1
        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider


    def set_targets(self):
        """
        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
        """
        self.target_depth = self.target_depth * (self.number_of_planes - 1)
        self.target_depth = torch.round(self.target_depth, decimals = 0)
        self.targets      = torch.zeros(
                                        self.number_of_planes,
                                        self.target_image.shape[0],
                                        self.target_image.shape[1],
                                        self.target_image.shape[2],
                                        requires_grad = False,
                                        device = self.device
                                       )
        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
        self.masks        = torch.zeros_like(self.targets)
        for i in range(self.number_of_planes):
            for ch in range(self.target_image.shape[0]):
                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
                new_target = self.target_image[ch] * mask
                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
                self.masks[i, ch] = mask.detach().clone() 


    def add_defocus_blur(self):
        """
        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
        """
        kernel_length = [self.target_blur_size, self.target_blur_size ]
        for ch in range(self.target_image.shape[0]):
            targets_cache = self.targets[:, ch].detach().clone()
            target = torch.sum(targets_cache, axis = 0)
            for i in range(self.number_of_planes):
                sigmas = torch.linspace(start = 0, end = self.target_blur_size, steps = self.number_of_planes)
                sigmas = sigmas - i * self.target_blur_size / (self.number_of_planes - 1 + 1e-10)
                defocus = torch.zeros_like(targets_cache[i])
                for j in range(self.number_of_planes):
                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                    if torch.sum(targets_cache[j]) > 0:
                        if i == j:
                            nsigma = [0., 0.]
                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                        kernel = kernel / torch.sum(kernel)
                        kernel = kernel.unsqueeze(0).unsqueeze(0)
                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
                self.targets[i, ch] = defocus
        self.targets = self.targets.detach().clone() * self.multiplier


    def __call__(self, image, target, plane_id = None):
        """
        Calculates the multiplane loss against a given target.

        Parameters
        ----------
        image         : torch.tensor
                        Image to compare with a target [3 x m x n].
        target        : torch.tensor
                        Target image for comparison [3 x m x n].
        plane_id      : int
                        Number of the plane under test.

        Returns
        -------
        loss          : torch.tensor
                        Computed loss.
        """
        loss_components = {}
        if isinstance(plane_id, type(None)):
            mask = self.masks
        else:
            mask= self.masks[plane_id, :]
        l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)
        l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)
        l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)
        loss_components['l2'] = l2
        loss_components['l2_mask'] = l2_mask
        loss_components['l2_cor'] = l2_cor

        l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)
        l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)
        l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)
        loss_components['l1'] = l1
        loss_components['l1_mask'] = l1_mask
        loss_components['l1_cor'] = l1_cor

        for key in self.additional_loss_weights.keys():
            if key == 'cvvdp':
                loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
                loss_components['cvvdp'] = loss_cvvdp
            if key == 'fvvdp':
                loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
                loss_components['fvvdp'] = loss_fvvdp
            if key == 'lpips':
                loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
                loss_components['lpips'] = loss_lpips
            if key == 'psnr':
                loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
                loss_components['psnr'] = loss_psnr
            if key == 'ssim':
                loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
                loss_components['ssim'] = loss_ssim
            if key == 'msssim':
                loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
                loss_components['msssim'] = loss_msssim

        loss = torch.sum(torch.stack(list(loss_components.values())), dim = 0)

        if self.return_components:
            return loss, loss_components
        return loss

__call__(image, target, plane_id=None)

Calculates the multiplane loss against a given target.

Parameters:

  • image
            Image to compare with a target [3 x m x n].
    
  • target
            Target image for comparison [3 x m x n].
    
  • plane_id
            Number of the plane under test.
    

Returns:

  • loss ( tensor ) –

    Computed loss.

Source code in odak/learn/wave/loss.py
def __call__(self, image, target, plane_id = None):
    """
    Calculates the multiplane loss against a given target.

    Parameters
    ----------
    image         : torch.tensor
                    Image to compare with a target [3 x m x n].
    target        : torch.tensor
                    Target image for comparison [3 x m x n].
    plane_id      : int
                    Number of the plane under test.

    Returns
    -------
    loss          : torch.tensor
                    Computed loss.
    """
    loss_components = {}
    if isinstance(plane_id, type(None)):
        mask = self.masks
    else:
        mask= self.masks[plane_id, :]
    l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)
    l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)
    l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)
    loss_components['l2'] = l2
    loss_components['l2_mask'] = l2_mask
    loss_components['l2_cor'] = l2_cor

    l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)
    l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)
    l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)
    loss_components['l1'] = l1
    loss_components['l1_mask'] = l1_mask
    loss_components['l1_cor'] = l1_cor

    for key in self.additional_loss_weights.keys():
        if key == 'cvvdp':
            loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
            loss_components['cvvdp'] = loss_cvvdp
        if key == 'fvvdp':
            loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
            loss_components['fvvdp'] = loss_fvvdp
        if key == 'lpips':
            loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
            loss_components['lpips'] = loss_lpips
        if key == 'psnr':
            loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
            loss_components['psnr'] = loss_psnr
        if key == 'ssim':
            loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
            loss_components['ssim'] = loss_ssim
        if key == 'msssim':
            loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
            loss_components['msssim'] = loss_msssim

    loss = torch.sum(torch.stack(list(loss_components.values())), dim = 0)

    if self.return_components:
        return loss, loss_components
    return loss
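
As a rough illustration of how the base terms in the listing above are combined, the sketch below recomputes the L2 family (plain, masked, and the `image * target` correlation term) on flat Python lists and sums them with their weights. The helpers `mse` and `multiplane_l2_terms` are hypothetical names used only for this example and are not part of odak; the actual method works on tensors and also adds the L1 family and any additional perceptual terms.

```python
def mse(a, b):
    """Mean squared error between two equally sized flat lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def multiplane_l2_terms(image, target, mask, weights):
    """Hypothetical helper mirroring the l2, l2_mask and l2_cor terms."""
    masked_image = [i * m for i, m in zip(image, mask)]
    masked_target = [t * m for t, m in zip(target, mask)]
    image_times_target = [i * t for i, t in zip(image, target)]
    target_squared = [t * t for t in target]
    components = {
        'l2': weights['base_l2_loss'] * mse(image, target),
        'l2_mask': weights['loss_l2_mask'] * mse(masked_image, masked_target),
        'l2_cor': weights['loss_l2_cor'] * mse(image_times_target, target_squared),
    }
    # The total is the plain sum of all weighted components.
    return sum(components.values()), components


image = [0.0, 0.5, 1.0, 1.0]
target = [0.0, 1.0, 1.0, 0.0]
mask = [1, 1, 0, 0]
weights = {'base_l2_loss': 1.0, 'loss_l2_mask': 1.0, 'loss_l2_cor': 1.0}
loss, components = multiplane_l2_terms(image, target, mask, weights)
print(loss, components)
```

Setting a weight to zero removes the corresponding term from the total while still leaving its key in the components dictionary.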

__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, multiplier=1.0, scheme='defocus', base_loss_weights={'base_l2_loss': 1.0, 'loss_l2_mask': 1.0, 'loss_l2_cor': 1.0, 'base_l1_loss': 1.0, 'loss_l1_mask': 1.0, 'loss_l1_cor': 1.0}, additional_loss_weights={'cvvdp': 1.0}, reduction='mean', return_components=False, device=torch.device('cpu'))

Parameters:

  • target_image
                        Color target image [3 x m x n].
    
  • target_depth
                        Monochrome target depth, same resolution as target_image.
    
  • target_blur_size
                        Maximum target blur size.
    
  • blur_ratio
                        Blur ratio, a value between zero and one.
    
  • number_of_planes
                        Number of planes.
    
  • multiplier
                        Multiplier to multiply with targets.
    
  • scheme
                        The type of the loss, `naive` without defocus or `defocus` with defocus.
    
  • base_loss_weights
                        Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
    
  • additional_loss_weights (dict, default: {'cvvdp': 1.0} ) –
                        Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
    
  • reduction
                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    
  • return_components
                        If True (False by default), returns the components of the loss as a dict.
    
  • device
                        Device to be used (e.g., cuda, cpu, opencl).
    
Source code in odak/learn/wave/loss.py
def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
             target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', 
             base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},
             additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):
    """
    Parameters
    ----------
    target_image            : torch.tensor
                                Color target image [3 x m x n].
    target_depth            : torch.tensor
                                Monochrome target depth, same resolution as target_image.
    target_blur_size        : int
                                Maximum target blur size.
    blur_ratio              : float
                                Blur ratio, a value between zero and one.
    number_of_planes        : int
                                Number of planes.
    multiplier              : float
                                Multiplier to multiply with targets.
    scheme                  : str
                                The type of the loss, `naive` without defocus or `defocus` with defocus.
    base_loss_weights       : dict
                                Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
    additional_loss_weights : dict
                                Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
    reduction               : str
                                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    return_components       : bool
                                If True (False by default), returns the components of the loss as a dict.
    device                  : torch.device
                                Device to be used (e.g., cuda, cpu, opencl).
    """
    self.device = device
    self.target_image     = target_image.float().to(self.device)
    self.target_depth     = target_depth.float().to(self.device)
    self.target_blur_size = target_blur_size
    if self.target_blur_size % 2 == 0:
        self.target_blur_size += 1
    self.number_of_planes = number_of_planes
    self.multiplier       = multiplier
    self.reduction        = reduction
    if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:
        logging.warning("Reduction cannot be 'none' for additional loss functions. Changing reduction to 'mean'.")
        self.reduction = 'mean'
    self.blur_ratio       = blur_ratio
    self.set_targets()
    if scheme == 'defocus':
        self.add_defocus_blur()
    self.base_loss_weights = base_loss_weights
    self.additional_loss_weights = additional_loss_weights
    self.return_components = return_components
    self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)
    self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)
    for key in self.additional_loss_weights.keys():
        if key == 'cvvdp':
            self.cvvdp = CVVDP()
        if key == 'fvvdp':
            self.fvvdp = FVVDP()
        if key == 'lpips':
            self.lpips = LPIPS()
        if key == 'psnr':
            self.psnr = PSNR()
        if key == 'ssim':
            self.ssim = SSIM()
        if key == 'msssim':
            self.msssim = MSSSIM()
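
The constructor silently adjusts two of its inputs before using them: an even `target_blur_size` is bumped to the next odd number, and a reduction of 'none' falls back to 'mean' whenever additional loss terms are requested. The following minimal sketch isolates that logic; `normalize_settings` is a hypothetical helper for illustration and does not exist in odak.

```python
import logging


def normalize_settings(target_blur_size, reduction, additional_loss_weights):
    """Hypothetical helper reproducing the input adjustments made in the constructor."""
    # Even blur sizes are incremented so that the kernel size is odd.
    if target_blur_size % 2 == 0:
        target_blur_size += 1
    # Additional loss terms cannot be used with reduction 'none'.
    if reduction == 'none' and len(additional_loss_weights) > 0:
        logging.warning("Reduction cannot be 'none' for additional loss functions. Changing reduction to 'mean'.")
        reduction = 'mean'
    return target_blur_size, reduction


print(normalize_settings(10, 'none', {'cvvdp': 1.0}))
print(normalize_settings(7, 'none', {}))
```

To keep a reduction of 'none', pass an empty dictionary as `additional_loss_weights`, as in the second call.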

add_defocus_blur()

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def add_defocus_blur(self):
    """
    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
    """
    kernel_length = [self.target_blur_size, self.target_blur_size ]
    for ch in range(self.target_image.shape[0]):
        targets_cache = self.targets[:, ch].detach().clone()
        target = torch.sum(targets_cache, axis = 0)
        for i in range(self.number_of_planes):
            sigmas = torch.linspace(start = 0, end = self.target_blur_size, steps = self.number_of_planes)
            sigmas = sigmas - i * self.target_blur_size / (self.number_of_planes - 1 + 1e-10)
            defocus = torch.zeros_like(targets_cache[i])
            for j in range(self.number_of_planes):
                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                if torch.sum(targets_cache[j]) > 0:
                    if i == j:
                        nsigma = [0., 0.]
                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                    kernel = kernel / torch.sum(kernel)
                    kernel = kernel.unsqueeze(0).unsqueeze(0)
                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
            self.targets[i, ch] = defocus
    self.targets = self.targets.detach().clone() * self.multiplier
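
In the listing above, the Gaussian sigma used when plane `j` contributes to the target of plane `i` is `int(abs(i - j) * blur_ratio)`. The sketch below tabulates that value for every pair of planes; `blur_sigma_table` is a hypothetical helper written for this example only.

```python
def blur_sigma_table(number_of_planes, blur_ratio):
    """Hypothetical helper: sigma applied to plane j when building the target for plane i."""
    return [
        [int(abs(i - j) * blur_ratio) for j in range(number_of_planes)]
        for i in range(number_of_planes)
    ]


# Rows are the plane in focus (i), columns are the contributing plane (j).
table = blur_sigma_table(number_of_planes = 4, blur_ratio = 0.5)
for row in table:
    print(row)
```

Because the product is truncated with `int`, small values of `blur_ratio` give a sigma of zero for nearby planes; with four planes and `blur_ratio = 0.25`, every entry of this table is zero.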

get_targets()

Returns:

  • targets ( tensor ) –

    Returns a copy of the targets.

  • focus_target ( tensor ) –

    Returns a copy of the focus target accumulated in set_targets().

  • target_depth ( tensor ) –

    Returns a copy of the normalized quantized depth map.

Source code in odak/learn/wave/loss.py
def get_targets(self):
    """
    Returns
    -------
    targets           : torch.tensor
                        Returns a copy of the targets.
    focus_target      : torch.tensor
                        Returns a copy of the focus target accumulated in set_targets().
    target_depth      : torch.tensor
                        Returns a copy of the normalized quantized depth map.

    """
    divider = self.number_of_planes - 1
    if divider == 0:
        divider = 1
    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider

set_targets()

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):
    """
    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
    """
    self.target_depth = self.target_depth * (self.number_of_planes - 1)
    self.target_depth = torch.round(self.target_depth, decimals = 0)
    self.targets      = torch.zeros(
                                    self.number_of_planes,
                                    self.target_image.shape[0],
                                    self.target_image.shape[1],
                                    self.target_image.shape[2],
                                    requires_grad = False,
                                    device = self.device
                                   )
    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
    self.masks        = torch.zeros_like(self.targets)
    for i in range(self.number_of_planes):
        for ch in range(self.target_image.shape[0]):
            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
            new_target = self.target_image[ch] * mask
            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
            self.masks[i, ch] = mask.detach().clone() 
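
The slicing step can be pictured on a tiny single channel example: the depth map (assumed to be normalized between zero and one) is scaled by `number_of_planes - 1`, rounded, and each pixel is assigned to the plane whose index matches its rounded depth. The sketch below uses nested Python lists instead of tensors, and `slice_depth` is a hypothetical helper, not an odak function.

```python
def slice_depth(image, depth, number_of_planes):
    """Hypothetical helper: split a single channel image into per plane targets and masks."""
    scale = number_of_planes - 1
    quantized = [[round(d * scale) for d in row] for row in depth]
    targets, masks = [], []
    for plane in range(number_of_planes):
        # A pixel belongs to a plane when its quantized depth equals the plane index.
        mask = [[1 if q == plane else 0 for q in row] for row in quantized]
        target = [
            [pixel * m for pixel, m in zip(image_row, mask_row)]
            for image_row, mask_row in zip(image, mask)
        ]
        masks.append(mask)
        targets.append(target)
    return targets, masks, quantized


image = [[0.2, 0.4], [0.6, 0.8]]
depth = [[0.0, 0.3], [0.7, 1.0]]
targets, masks, quantized = slice_depth(image, depth, number_of_planes = 4)
print(quantized)
print(targets[1])
```

Each pixel lands in exactly one plane, so the per plane masks of this sketch sum to one everywhere.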

phase_gradient

Bases: Module

The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

This implements a convolution of the phase with a kernel.

The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.

Source code in odak/learn/wave/loss.py
class phase_gradient(nn.Module):

    """
    The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

    This implements a convolution of the phase with a kernel.

    The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.
    """


    def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
        """
        Parameters
        ----------
        kernel                  : torch.tensor
                                    Convolution filter kernel, 3 by 3 Laplacian kernel by default.
        loss                    : torch.nn.Module
                                    loss function, L2 Loss by default.
        """
        super(phase_gradient, self).__init__()
        self.device = device
        self.loss = loss
        if kernel is None:
            self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
        else:
            if len(kernel.shape) == 4:
                self.kernel = kernel
            else:
                self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
        self.kernel = Variable(self.kernel.to(self.device))


    def forward(self, phase):
        """
        Calculates the phase gradient Loss.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(phase.shape) == 2:
            phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
        edge_detect = self.functional_conv2d(phase)
        loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
        return loss_value


    def functional_conv2d(self, phase):
        """
        Calculates the gradient of the phase.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        edge_detect              : torch.tensor
                                    The computed phase gradient.
        """
        edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
        return edge_detect

__init__(kernel=None, loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel
                        Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    
  • loss
                        loss function, L2 Loss by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
    """
    Parameters
    ----------
    kernel                  : torch.tensor
                                Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    loss                    : torch.nn.Module
                                loss function, L2 Loss by default.
    """
    super(phase_gradient, self).__init__()
    self.device = device
    self.loss = loss
    if kernel is None:
        self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
    else:
        if len(kernel.shape) == 4:
            self.kernel = kernel
        else:
            self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
    self.kernel = Variable(self.kernel.to(self.device))

forward(phase)

Calculates the phase gradient Loss.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, phase):
    """
    Calculates the phase gradient Loss.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(phase.shape) == 2:
        phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
    edge_detect = self.functional_conv2d(phase)
    loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
    return loss_value

functional_conv2d(phase)

Calculates the gradient of the phase.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • edge_detect ( tensor ) –

    The computed phase gradient.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, phase):
    """
    Calculates the gradient of the phase.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    edge_detect              : torch.tensor
                                The computed phase gradient.
    """
    edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
    return edge_detect
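
To see what the default kernel responds to, the following sketch applies the same normalized 3 by 3 Laplacian with zero padding to small phase maps stored as nested lists. It is a plain Python stand-in for the `F.conv2d` call above, written under the assumption of a single channel input; `filter_same` is a hypothetical helper.

```python
LAPLACIAN = [
    [-1 / 8, -1 / 8, -1 / 8],
    [-1 / 8, 1.0, -1 / 8],
    [-1 / 8, -1 / 8, -1 / 8],
]


def filter_same(image, kernel):
    """Zero padded 3 x 3 filtering that keeps the input size."""
    rows, cols = len(image), len(image[0])
    output = [[0.0] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            total = 0.0
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    rr, cc = r + dr, c + dc
                    # Samples outside the image count as zero (zero padding).
                    if 0 <= rr < rows and 0 <= cc < cols:
                        total += image[rr][cc] * kernel[dr + 1][dc + 1]
            output[r][c] = total
    return output


flat = [[1.0] * 5 for _ in range(5)]
ramp = [[float(c) for c in range(5)] for _ in range(5)]
step = [[0.0, 0.0, 1.0, 1.0, 1.0] for _ in range(5)]

print(filter_same(flat, LAPLACIAN)[2][2])
print(filter_same(ramp, LAPLACIAN)[2][2])
print(filter_same(step, LAPLACIAN)[2][2])
```

Away from the border, constant and linearly varying phase produce no response, while a phase step does. Pixels on the border respond even for constant phase because of the zero padding.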

speckle_contrast

Bases: Module

The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast, and mean and sigma are the mean and standard deviation of the intensity.

We refer to the following paper:

Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports. 10. 18832. 10.1038/s41598-020-75947-0.

Source code in odak/learn/wave/loss.py
class speckle_contrast(nn.Module):

    """
    The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast, and mean and sigma are the mean and standard deviation of the intensity.

    We refer to the following paper:

    Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports. 10. 18832. 10.1038/s41598-020-75947-0.
    """


    def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
        """
        Parameters
        ----------
        kernel_size             : int
                                    Convolution filter kernel size, 11 by 11 average kernel by default.
        step_size               : tuple
                                    Convolution stride in height and width direction.
        loss                    : torch.nn.Module
                                    loss function, L2 Loss by default.
        """
        super(speckle_contrast, self).__init__()
        self.device = device
        self.loss = loss
        self.step_size = step_size
        self.kernel_size = kernel_size
        self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
        self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))


    def forward(self, intensity):
        """
        Calculates the speckle contrast Loss.

        Parameters
        ----------
        intensity               : torch.tensor
                                    intensity of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(intensity.shape) == 2:
            intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
        Speckle_C = self.functional_conv2d(intensity)
        loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
        return loss_value


    def functional_conv2d(self, intensity):
        """
        Calculates the speckle contrast of the intensity.

        Parameters
        ----------
        intensity                : torch.tensor
                                    Intensity of the complex field.

        Returns
        -------

        Speckle_C               : torch.tensor
                                    The computed speckle contrast.
        """
        mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
        var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
        Speckle_C = var / mean
        return Speckle_C

__init__(kernel_size=11, step_size=(1, 1), loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel_size
                        Convolution filter kernel size, 11 by 11 average kernel by default.
    
  • step_size
                        Convolution stride in height and width direction.
    
  • loss
                        loss function, L2 Loss by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
    """
    Parameters
    ----------
    kernel_size             : int
                                Convolution filter kernel size, 11 by 11 average kernel by default.
    step_size               : tuple
                                Convolution stride in height and width direction.
    loss                    : torch.nn.Module
                                loss function, L2 Loss by default.
    """
    super(speckle_contrast, self).__init__()
    self.device = device
    self.loss = loss
    self.step_size = step_size
    self.kernel_size = kernel_size
    self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
    self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))

forward(intensity)

Calculates the speckle contrast Loss.

Parameters:

  • intensity
                        intensity of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, intensity):
    """
    Calculates the speckle contrast Loss.

    Parameters
    ----------
    intensity               : torch.tensor
                                intensity of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(intensity.shape) == 2:
        intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
    Speckle_C = self.functional_conv2d(intensity)
    loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
    return loss_value

functional_conv2d(intensity)

Calculates the speckle contrast of the intensity.

Parameters:

  • intensity
                        Intensity of the complex field.
    

Returns:

  • Speckle_C ( tensor ) –

    The computed speckle contrast.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, intensity):
    """
    Calculates the speckle contrast of the intensity.

    Parameters
    ----------
    intensity                : torch.tensor
                                Intensity of the complex field.

    Returns
    -------

    Speckle_C               : torch.tensor
                                The computed speckle contrast.
    """
    mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
    var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
    Speckle_C = var / mean
    return Speckle_C
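
For a single window, the computation above reduces to the standard deviation divided by the mean, with the variance obtained as the mean of the squares minus the square of the mean. The sketch below evaluates that for one window given as a flat list of intensity samples; `window_speckle_contrast` is a hypothetical helper for illustration, whereas the class evaluates every window at once with an averaging convolution.

```python
import math


def window_speckle_contrast(window):
    """Speckle contrast C = sigma / mean for one flat list of intensity samples."""
    count = len(window)
    mean = sum(window) / count
    mean_of_squares = sum(v * v for v in window) / count
    # Population variance: E[I^2] - (E[I])^2.
    sigma = math.sqrt(mean_of_squares - mean ** 2)
    return sigma / mean


uniform = [2.0] * 9
speckled = [0.0, 4.0, 0.0, 4.0, 0.0, 4.0, 0.0, 4.0]
print(window_speckle_contrast(uniform))
print(window_speckle_contrast(speckled))
```

A uniform window has zero contrast, which is the value the regularizer pushes towards by comparing the contrast map against zeros.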

holobeam_multiholo

Bases: Module

The learned holography model used in the paper, Akşit, Kaan, and Yuta Itoh. "HoloBeam: Paper-Thin Near-Eye Displays." In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.

Parameters:

  • n_input
                Number of channels in the input.
    
  • n_hidden
                Number of channels in the hidden layers.
    
  • n_output
                Number of channels in the output layer.
    
  • device
                Default device is CPU.
    
  • reduction
                Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
    
Source code in odak/learn/wave/models.py
class holobeam_multiholo(torch.nn.Module):
    """
    The learned holography model used in the paper, Akşit, Kaan, and Yuta Itoh. "HoloBeam: Paper-Thin Near-Eye Displays." In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.


    Parameters
    ----------
    n_input           : int
                        Number of channels in the input.
    n_hidden          : int
                        Number of channels in the hidden layers.
    n_output          : int
                        Number of channels in the output layer.
    device            : torch.device
                        Default device is CPU.
    reduction         : str
                        Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
    """
    def __init__(
                 self,
                 n_input = 1,
                 n_hidden = 16,
                 n_output = 2,
                 device = torch.device('cpu'),
                 reduction = 'sum'
                ):
        super(holobeam_multiholo, self).__init__()
        torch.random.seed()
        self.device = device
        self.reduction = reduction
        self.l2 = torch.nn.MSELoss(reduction = self.reduction)
        self.l1 = torch.nn.L1Loss(reduction = self.reduction)
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_output = n_output
        self.network = unet(
                            dimensions = self.n_hidden,
                            input_channels = self.n_input,
                            output_channels = self.n_output
                           ).to(self.device)


    def forward(self, x, test = False):
        """
        Internal function representing the forward model.
        """
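        if test:
            # Gradients are only disabled inside a torch.no_grad() context.
            with torch.no_grad():
                y = self.network.forward(x)
        else:
            y = self.network.forward(x)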
        phase_low = y[:, 0].unsqueeze(1)
        phase_high = y[:, 1].unsqueeze(1)
        phase_only = torch.zeros_like(phase_low)
        phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]
        phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]
        phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]
        phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]
        return phase_only


    def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):
        """
        Internal function for evaluating the loss as a weighted sum of L2 and L1 terms.
        """
        loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
        return loss


    def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):
        """
        Function to train the weights of the network.

        Parameters
        ----------
        dataloader       : torch.utils.data.DataLoader
                           Data loader.
        number_of_epochs : int
                           Number of epochs.
        learning_rate    : float
                           Learning rate of the optimizer.
        directory        : str
                           Output directory.
        save_at_every    : int
                           Save the model at every given epoch count.
        """
        t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        for i in t_epoch:
            epoch_loss = 0.
            t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)
            for j, data in enumerate(t_data):
                self.optimizer.zero_grad()
                images, holograms = data
                estimates = self.forward(images)
                loss = self.evaluate(estimates, holograms)
                loss.backward(retain_graph=True)
                self.optimizer.step()
                description = 'Loss:{:.4f}'.format(loss.item())
                t_data.set_description(description)
                epoch_loss += float(loss.item()) / dataloader.__len__()
            description = 'Epoch Loss:{:.4f}'.format(epoch_loss)
            t_epoch.set_description(description)
            if i % save_at_every == 0:
                self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))
        self.save_weights(filename='{}/weights.pt'.format(directory))
        print(description)


    def save_weights(self, filename = './weights.pt'):
        """
        Function to save the current weights of the network to a file.
        Parameters
        ----------
        filename        : str
                          Filename.
        """
        torch.save(self.network.state_dict(), os.path.expanduser(filename))


    def load_weights(self, filename = './weights.pt'):
        """
        Function to load weights for this network from a file.
        Parameters
        ----------
        filename        : str
                          Filename.
        """
        self.network.load_state_dict(torch.load(os.path.expanduser(filename)))
        self.network.eval()

evaluate(input_data, ground_truth, weights=[1.0, 0.1])

Internal function for evaluating the loss as a weighted sum of L2 and L1 terms.

Source code in odak/learn/wave/models.py
def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):
    """
    Internal function for evaluating the loss as a weighted sum of L2 and L1 terms.
    """
    loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
    return loss
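
With the class default of reduction 'sum', the loss above is the sum of squared differences weighted by `weights[0]` plus the sum of absolute differences weighted by `weights[1]`. A minimal sketch on flat Python lists, assuming that reduction; `weighted_l2_l1` is a hypothetical helper and not part of odak:

```python
def weighted_l2_l1(estimate, ground_truth, weights = (1.0, 0.1)):
    """Hypothetical helper: weighted sum of summed squared and summed absolute errors."""
    squared = sum((e - g) ** 2 for e, g in zip(estimate, ground_truth))
    absolute = sum(abs(e - g) for e, g in zip(estimate, ground_truth))
    return weights[0] * squared + weights[1] * absolute


loss = weighted_l2_l1([0.0, 1.0, 2.0], [0.0, 0.5, 0.0])
print(loss)
```

Here the squared errors sum to 4.25 and the absolute errors to 2.5, so the default weights give roughly 4.25 + 0.25.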

fit(dataloader, number_of_epochs=100, learning_rate=1e-05, directory='./output', save_at_every=100)

Function to train the weights of the network.

Parameters:

  • dataloader
               Data loader.
    
  • number_of_epochs (int, default: 100 ) –
               Number of epochs.
    
  • learning_rate
               Learning rate of the optimizer.
    
  • directory
               Output directory.
    
  • save_at_every
               Save the model at every given epoch count.
    
Source code in odak/learn/wave/models.py
def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):
    """
    Function to train the weights of the network.

    Parameters
    ----------
    dataloader       : torch.utils.data.DataLoader
                       Data loader.
    number_of_epochs : int
                       Number of epochs.
    learning_rate    : float
                       Learning rate of the optimizer.
    directory        : str
                       Output directory.
    save_at_every    : int
                       Save the model at every given epoch count.
    """
    t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)
    self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
    for i in t_epoch:
        epoch_loss = 0.
        t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)
        for j, data in enumerate(t_data):
            self.optimizer.zero_grad()
            images, holograms = data
            estimates = self.forward(images)
            loss = self.evaluate(estimates, holograms)
            loss.backward(retain_graph=True)
            self.optimizer.step()
            description = 'Loss:{:.4f}'.format(loss.item())
            t_data.set_description(description)
            epoch_loss += float(loss.item()) / dataloader.__len__()
        description = 'Epoch Loss:{:.4f}'.format(epoch_loss)
        t_epoch.set_description(description)
        if i % save_at_every == 0:
            self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))
    self.save_weights(filename='{}/weights.pt'.format(directory))
    print(description)
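
As a minimal sketch of the checkpointing behaviour in the loop above, the snippet below lists the files that `fit` would write for a given configuration. The helper name `checkpoint_filenames` is hypothetical and not part of odak; it only replays the `i % save_at_every == 0` condition and the filename patterns from the source.

```python
def checkpoint_filenames(number_of_epochs=100, save_at_every=100, directory='./output'):
    """Hypothetical helper: list the files the loop in `fit` would write, in order."""
    filenames = []
    for i in range(number_of_epochs):
        # Same condition as in `fit`; epoch indices are zero-based.
        if i % save_at_every == 0:
            filenames.append('{}/weights_{:04d}.pt'.format(directory, i))
    # `fit` always writes the final weights once the loop has finished.
    filenames.append('{}/weights.pt'.format(directory))
    return filenames


default_files = checkpoint_filenames()
frequent_files = checkpoint_filenames(number_of_epochs=10, save_at_every=4)
print(default_files)
print(frequent_files)
```

With the default arguments, the only intermediate checkpoint is the one written at epoch 0, because no later epoch index below 100 is a multiple of 100.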

forward(x, test=False)

Internal function representing the forward model.

Source code in odak/learn/wave/models.py
def forward(self, x, test = False):
    """
    Internal function representing the forward model.
    """
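    if test:
        # Note: a bare call to torch.no_grad() has no effect; gradients are only disabled inside a `with torch.no_grad():` block.
        torch.no_grad()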
    y = self.network.forward(x) 
    phase_low = y[:, 0].unsqueeze(1)
    phase_high = y[:, 1].unsqueeze(1)
    phase_only = torch.zeros_like(phase_low)
    phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]
    phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]
    phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]
    phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]
    return phase_only
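
The four slice assignments above interleave the two network outputs in a checkerboard pattern: pixels whose row and column indices sum to an even number come from `phase_low`, and the rest come from `phase_high`. The sketch below illustrates this on a single 2D plane, using nested Python lists as a stand-in for tensors. The function name `interleave_checkerboard` is an assumption made for illustration only.

```python
def interleave_checkerboard(phase_low, phase_high):
    """Hypothetical helper: checkerboard-interleave two equally sized 2D grids."""
    rows, cols = len(phase_low), len(phase_low[0])
    phase_only = [[0.0] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            # (even, even) and (odd, odd) pixels take the low phase,
            # (even, odd) and (odd, even) pixels take the high phase.
            source = phase_low if (r + c) % 2 == 0 else phase_high
            phase_only[r][c] = source[r][c]
    return phase_only


low = [[1.0] * 4 for _ in range(4)]
high = [[2.0] * 4 for _ in range(4)]
pattern = interleave_checkerboard(low, high)
for row in pattern:
    print(row)
```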

load_weights(filename='./weights.pt')

Function to load weights for this multi layer perceptron from a file.

Parameters:

  • filename
              Filename.
    
Source code in odak/learn/wave/models.py
def load_weights(self, filename = './weights.pt'):
    """
    Function to load weights for this multi layer perceptron from a file.

    Parameters
    ----------
    filename        : str
                      Filename.
    """
    self.network.load_state_dict(torch.load(os.path.expanduser(filename)))
    self.network.eval()

save_weights(filename='./weights.pt')

Function to save the current weights of the multi layer perceptron to a file.

Parameters:

  • filename
              Filename.
    
Source code in odak/learn/wave/models.py
def save_weights(self, filename = './weights.pt'):
    """
    Function to save the current weights of the multi layer perceptron to a file.

    Parameters
    ----------
    filename        : str
                      Filename.
    """
    torch.save(self.network.state_dict(), os.path.expanduser(filename))
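
A save and load round trip with the same PyTorch calls might look like the sketch below. The `torch.nn.Linear` module is only a stand-in for `self.network`, whose architecture is not shown in this section, and the temporary directory is an assumption to keep the example self-contained.

```python
import os
import tempfile

import torch

# Stand-in for `self.network`; the real architecture is not shown here.
network = torch.nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as directory:
    filename = os.path.join(directory, 'weights.pt')
    # Same calls as in `save_weights`.
    torch.save(network.state_dict(), os.path.expanduser(filename))
    # Same calls as in `load_weights`, applied to a freshly initialised copy.
    restored = torch.nn.Linear(4, 2)
    restored.load_state_dict(torch.load(os.path.expanduser(filename)))
    restored.eval()

# Compare every parameter tensor of the original and the restored module.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(network.state_dict().values(), restored.state_dict().values())
)
print(weights_match)
```

Because both functions pass the filename through `os.path.expanduser`, paths starting with `~` resolve to the user's home directory.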

multi_color_hologram_optimizer

A class for optimizing single or multi-color holograms. For more details, see Kavaklı et al., SIGGRAPH ASIA 2023, Multi-color Holograms Improve Brightness in Holographic Displays.
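
The `gradient_descent` method of this class lowers the learning rate linearly from `learning_rate` towards `learning_rate_floor` and clamps it at the floor. The plain-Python sketch below replays that schedule without any optimizer. The function name `learning_rate_schedule` is hypothetical, and the default values are copied from the constructor signature.

```python
def learning_rate_schedule(learning_rate=2e-2, learning_rate_floor=5e-3, number_of_iterations=100):
    """Hypothetical helper: learning rate used at each iteration of the loop."""
    decrement = (learning_rate - learning_rate_floor) / number_of_iterations
    current = learning_rate
    schedule = []
    for _ in range(number_of_iterations):
        # The rate is lowered before the optimizer step of each iteration.
        current -= decrement
        # It never drops below the floor.
        if current < learning_rate_floor:
            current = learning_rate_floor
        schedule.append(current)
    return schedule


schedule = learning_rate_schedule()
print(schedule[0], schedule[-1])
```

Since the decrement is applied before the first optimizer step, the initial `learning_rate` itself is never used for an update; the first step already runs at a slightly lower rate.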

Source code in odak/learn/wave/optimizers.py
class multi_color_hologram_optimizer():
    """
    A class for optimizing single or multi color holograms.
    For more details, see Kavaklı et al., SIGGRAPH ASIA 2023, Multi-color Holograms Improve Brightness in Holographic Displays.
    """
    def __init__(self,
                 wavelengths,
                 resolution,
                 targets,
                 propagator,
                 number_of_frames = 3,
                 number_of_depth_layers = 1,
                 learning_rate = 2e-2,
                 learning_rate_floor = 5e-3,
                 double_phase = True,
                 scale_factor = 1,
                 method = 'multi-color',
                 channel_power_filename = '',
                 device = None,
                 loss_function = None,
                 peak_amplitude = 1.0,
                 optimize_peak_amplitude = False,
                 img_loss_thres = 2e-3,
                 reduction = 'sum'
                ):
        self.device = device
        if isinstance(self.device, type(None)):
            self.device = torch.device("cpu")
        torch.cuda.empty_cache()
        torch.random.seed()
        self.wavelengths = wavelengths
        self.resolution = resolution
        self.targets = targets
        if propagator.propagation_type != 'Impulse Response Fresnel':
            scale_factor = 1
        self.scale_factor = scale_factor
        self.propagator = propagator
        self.learning_rate = learning_rate
        self.learning_rate_floor = learning_rate_floor
        self.number_of_channels = len(self.wavelengths)
        self.number_of_frames = number_of_frames
        self.number_of_depth_layers = number_of_depth_layers
        self.double_phase = double_phase
        self.channel_power_filename = channel_power_filename
        self.method = method
        if self.method != 'conventional' and self.method != 'multi-color':
            logging.warning('Unknown optimization method. Options are conventional or multi-color.')
            import sys
            sys.exit()
        self.peak_amplitude = peak_amplitude
        self.optimize_peak_amplitude = optimize_peak_amplitude
        if self.optimize_peak_amplitude:
            self.init_peak_amplitude_scale()
        self.img_loss_thres = img_loss_thres
        self.kernels = []
        self.init_phase()
        self.init_channel_power()
        self.init_loss_function(loss_function, reduction = reduction)
        self.init_amplitude()
        self.init_phase_scale()


    def init_peak_amplitude_scale(self):
        """
        Internal function to turn the peak amplitude into an optimizable tensor.
        """
        self.peak_amplitude = torch.tensor(
                                           self.peak_amplitude,
                                           requires_grad = True,
                                           device=self.device
                                          )


    def init_phase_scale(self):
        """
        Internal function to set the phase scale.
        """
        if self.method == 'conventional':
            self.phase_scale = torch.tensor(
                                            [
                                             1.,
                                             1.,
                                             1.
                                            ],
                                            requires_grad = False,
                                            device = self.device
                                           )
        if self.method == 'multi-color':
            self.phase_scale = torch.tensor(
                                            [
                                             1.,
                                             1.,
                                             1.
                                            ],
                                            requires_grad = False,
                                            device = self.device
                                           )


    def init_amplitude(self):
        """
        Internal function to set the amplitude of the illumination source.
        """
        self.amplitude = torch.zeros(
                                     self.resolution[0] * self.scale_factor,
                                     self.resolution[1] * self.scale_factor,
                                     requires_grad = False,
                                     device = self.device
                                    )
        self.amplitude[::self.scale_factor, ::self.scale_factor] = 1.


    def init_phase(self):
        """
        Internal function to set the starting phase of the phase-only hologram.
        """
        self.phase = torch.zeros(
                                 self.number_of_frames,
                                 self.resolution[0],
                                 self.resolution[1],
                                 device = self.device,
                                 requires_grad = True
                                )
        self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)


    def init_channel_power(self):
        """
        Internal function to set the channel (laser) powers for each frame and color channel.
        """
        if self.method == 'conventional':
            logging.warning('Scheme: Conventional')
            self.channel_power = torch.eye(
                                           self.number_of_frames,
                                           self.number_of_channels,
                                           device = self.device,
                                           requires_grad = False
                                          )

        elif self.method == 'multi-color':
            logging.warning('Scheme: Multi-color')
            self.channel_power = torch.ones(
                                            self.number_of_frames,
                                            self.number_of_channels,
                                            device = self.device,
                                            requires_grad = True
                                           )
        if self.channel_power_filename != '':
            self.channel_power = torch_load(self.channel_power_filename).to(self.device)
            self.channel_power.requires_grad = False
            self.channel_power[self.channel_power < 0.] = 0.
            self.channel_power[self.channel_power > 1.] = 1.
            if self.method == 'multi-color':
                self.channel_power.requires_grad = True
            if self.method == 'conventional':
                self.channel_power = torch.abs(torch.cos(self.channel_power))
            logging.warning('Channel powers:')
            logging.warning(self.channel_power)
            logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))
        self.propagator.set_laser_powers(self.channel_power)



    def init_optimizer(self):
        """
        Internal function to set the optimizer.
        """
        optimization_variables = [self.phase, self.offset]
        if self.optimize_peak_amplitude:
            optimization_variables.append(self.peak_amplitude)
        if self.method == 'multi-color':
            optimization_variables.append(self.propagator.channel_power)
        self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)


    def init_loss_function(self, loss_function, reduction = 'sum'):
        """
        Internal function to set the loss function.
        """
        self.l2_loss = torch.nn.MSELoss(reduction = reduction)
        self.loss_type = 'custom'
        self.loss_function = loss_function
        if isinstance(self.loss_function, type(None)):
            self.loss_type = 'conventional'
            self.loss_function = torch.nn.MSELoss(reduction = reduction)



    def evaluate(self, input_image, target_image, plane_id = 0):
        """
        Internal function to evaluate the loss.
        """
        if self.loss_type == 'conventional':
            loss = self.loss_function(input_image, target_image)
        elif self.loss_type == 'custom':
            loss = 0
            for i in range(len(self.wavelengths)):
                loss += self.loss_function(
                                           input_image[i],
                                           target_image[i],
                                           plane_id = plane_id
                                          )
        return loss


    def double_phase_constrain(self, phase, phase_offset):
        """
        Internal function to constrain a given phase similarly to double phase encoding.

        Parameters
        ----------
        phase                      : torch.tensor
                                     Input phase values to be constrained.
        phase_offset               : torch.tensor
                                     Input phase offset value.

        Returns
        -------
        phase_only                 : torch.tensor
                                     Constrained output phase.
        """
        phase_zero_mean = phase - torch.mean(phase)
        phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * np.pi)
        phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * np.pi)
        loss = multi_scale_total_variation_loss(phase_low, levels = 6)
        loss += multi_scale_total_variation_loss(phase_high, levels = 6)
        loss += torch.std(phase_low)
        loss += torch.std(phase_high)
        phase_only = torch.zeros_like(phase)
        phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
        phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
        phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
        phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
        return phase_only, loss


    def direct_phase_constrain(self, phase, phase_offset):
        """
        Internal function to constrain a given phase.

        Parameters
        ----------
        phase                      : torch.tensor
                                     Input phase values to be constrained.
        phase_offset               : torch.tensor
                                     Input phase offset value.

        Returns
        -------
        phase_only                 : torch.tensor
                                     Constrained output phase.
        """
        phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi)
        loss = multi_scale_total_variation_loss(phase, levels = 6)
        loss += multi_scale_total_variation_loss(phase_offset, levels = 6)
        return phase_only, loss


    def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
        """
        Function to optimize multiplane phase-only holograms using stochastic gradient descent.

        Parameters
        ----------
        number_of_iterations       : int
                                     Number of iterations.
        weights                    : list
                                     Weights used in the loss function.

        Returns
        -------
        hologram_phases            : torch.tensor
                                     Phases of the optimized hologram, one per frame.
        """
        hologram_phases = torch.zeros(
                                      self.number_of_frames,
                                      self.resolution[0],
                                      self.resolution[1],
                                      device = self.device
                                     )
        t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)
        if self.optimize_peak_amplitude:
            peak_amp_cache = self.peak_amplitude.item()
        for step in t:
            for g in self.optimizer.param_groups:
                g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations
                if g['lr'] < self.learning_rate_floor:
                    g['lr'] = self.learning_rate_floor
                learning_rate = g['lr']
            total_loss = 0
            t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)
            for depth_id in t_depth:
                self.optimizer.zero_grad()
                depth_target = self.targets[depth_id]
                reconstruction_intensities = torch.zeros(
                                                         self.number_of_frames,
                                                         self.number_of_channels,
                                                         self.resolution[0] * self.scale_factor,
                                                         self.resolution[1] * self.scale_factor,
                                                         device = self.device
                                                        )
                loss_variation_hologram = 0
                laser_powers = self.propagator.get_laser_powers()
                for frame_id in range(self.number_of_frames):
                    if self.double_phase:
                        phase, loss_phase = self.double_phase_constrain(
                                                                        self.phase[frame_id],
                                                                        self.offset[frame_id]
                                                                       )
                    else:
                        phase, loss_phase = self.direct_phase_constrain(
                                                                        self.phase[frame_id],
                                                                        self.offset[frame_id]
                                                                       )
                    loss_variation_hologram += loss_phase
                    for channel_id in range(self.number_of_channels):
                        phase_scaled = torch.zeros_like(self.amplitude)
                        phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
                        laser_power = laser_powers[frame_id][channel_id]
                        hologram = generate_complex_field(
                                                          laser_power * self.amplitude,
                                                          phase_scaled * self.phase_scale[channel_id]
                                                         )
                        reconstruction_field = self.propagator(hologram, channel_id, depth_id)
                        intensity = calculate_amplitude(reconstruction_field) ** 2
                        reconstruction_intensities[frame_id, channel_id] += intensity
                    hologram_phases[frame_id] = phase.detach().clone()
                loss_laser = self.l2_loss(
                                          torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,
                                          torch.sum(laser_powers, dim = 0)
                                         )
                loss_laser += self.l2_loss(
                                           torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),
                                           torch.sum(laser_powers).view(1,)
                                          )
                loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))
                reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)
                loss_image = self.evaluate(
                                           reconstruction_intensity,
                                           depth_target * self.peak_amplitude,
                                           plane_id = depth_id
                                          )
                loss = weights[0] * loss_image
                loss += weights[1] * loss_laser
                loss += weights[2] * loss_variation_hologram
                include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres
                if include_pa_loss_flag:
                    loss -= self.peak_amplitude * 1.
                if self.method == 'conventional':
                    loss.backward()
                else:
                    loss.backward(retain_graph = True)
                self.optimizer.step()
                if include_pa_loss_flag:
                    peak_amp_cache = self.peak_amplitude.item()
                else:
                    with torch.no_grad():
                        if self.optimize_peak_amplitude:
                            self.peak_amplitude.view([1])[0] = peak_amp_cache
                total_loss += loss.detach().item()
                loss_image = loss_image.detach()
                del loss_laser
                del loss_variation_hologram
                del loss
            description = "Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)
            t.set_description(description)
            del total_loss
            del loss_image
            del reconstruction_field
            del reconstruction_intensities
            del intensity
            del phase
            del hologram
        logging.warning(description)
        return hologram_phases.detach()


    def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8):
        """
        Function to optimize multiplane phase-only holograms.

        Parameters
        ----------
        number_of_iterations       : int
                                     Number of iterations.
        weights                    : list
                                     Loss weights.
        bits                       : int
                                     Quantizes the hologram using the given bits and reconstructs.

        Returns
        -------
        hologram_phases            : torch.tensor
                                     Phases of the optimized phase-only hologram.
        reconstruction_intensities : torch.tensor
                                     Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
        """
        self.init_optimizer()
        hologram_phases = self.gradient_descent(
                                                number_of_iterations=number_of_iterations,
                                                weights=weights
                                               )
        hologram_phases = quantize(hologram_phases % (2 * np.pi), bits = bits, limits = [0., 2 * np.pi]) / 2 ** bits * 2 * np.pi
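        # Note: a bare call to torch.no_grad() has no effect here; it only disables gradients when used as a context manager or decorator.
        torch.no_grad()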
        reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
        laser_powers = self.