
odak.learn.wave

angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution-based beam propagation with the Angular Spectrum method.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def angular_spectrum(
                     field,
                     k,
                     distance,
                     dx,
                     wavelength,
                     zero_padding = False,
                     aperture = 1.
                    ):
    """
    A definition to calculate convolution-based beam propagation with the Angular Spectrum method.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Angular Spectrum',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
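
A minimal usage sketch, assuming `angular_spectrum` and `wavenumber` are exported by odak.learn.wave: propagate a point source over one centimeter (all values below are hypothetical).

import torch
from odak.learn.wave import angular_spectrum, wavenumber

wavelength = 515e-9                                   # hypothetical green source
dx = 8e-6                                             # hypothetical pixel pitch in meters
distance = 1e-2                                       # one centimeter
k = wavenumber(wavelength)
field = torch.zeros(512, 512, dtype = torch.complex64)
field[256, 256] = 1.                                  # a single point source
result = angular_spectrum(field, k, distance, dx, wavelength)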

band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate bandlimited angular spectrum based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics Express 17.22 (2009): 19662-19673.

Parameters:

  • field
               A complex field.
               The expected size is [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def band_limited_angular_spectrum(
                                  field,
                                  k,
                                  distance,
                                  dx,
                                  wavelength,
                                  zero_padding = False,
                                  aperture = 1.
                                 ):
    """
    A definition to calculate bandlimited angular spectrum based beam propagation. For more, see
    `Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673`.

    Parameters
    ----------
    field            : torch.complex
                       A complex field.
                       The expected size is [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Bandlimited Angular Spectrum',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
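
A short sketch of the `aperture` argument, assuming the import path below; the circular mask is a hypothetical Fourier-domain pinhole matching the field resolution.

import torch
from odak.learn.wave import band_limited_angular_spectrum, wavenumber

wavelength = 515e-9
dx = 8e-6
k = wavenumber(wavelength)
field = torch.rand(512, 512).to(torch.complex64)      # amplitude-only input field
fx = torch.linspace(-1., 1., 512)
FY, FX = torch.meshgrid(fx, fx, indexing = 'ij')
aperture = ((FX ** 2 + FY ** 2) < 0.5 ** 2).float()   # hypothetical circular pinhole
result = band_limited_angular_spectrum(field, k, 1e-2, dx, wavelength, aperture = aperture)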

custom(field, kernel, zero_padding=False, aperture=1.0)

A definition to calculate beam propagation by applying a custom complex kernel to the input field in the Fourier domain.

Parameters:

  • field
               Complex field [m x n].
    
  • kernel
               Custom complex kernel for beam propagation.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def custom(
           field,
           kernel,
           zero_padding = False,
           aperture = 1.
          ):
    """
    A definition to calculate beam propagation by applying a custom complex kernel to the input field in the Fourier domain.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    kernel           : torch.complex
                       Custom complex kernel for beam propagation.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].

    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].

    """
    if type(kernel) == type(None):
        H = torch.ones(field.shape).to(field.device)
    else:
        H = kernel * aperture
    U1 = torch.fft.fftshift(torch.fft.fft2(field)) * aperture
    if zero_padding == False:
        U2 = H * U1
    elif zero_padding == True:
        U2 = zero_pad(H * U1)
    result = torch.fft.ifft2(torch.fft.ifftshift(U2))
    return result
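
Since `angular_spectrum` rebuilds its kernel on every call (see its source above), a kernel from `get_propagation_kernel` can be computed once and reused through `custom`; a sketch assuming both names are exported by odak.learn.wave.

import torch
from odak.learn.wave import custom, get_propagation_kernel

field = torch.randn(512, 512, dtype = torch.complex64)
H = get_propagation_kernel(
                           nu = field.shape[-2],
                           nv = field.shape[-1],
                           dx = 8e-6,
                           wavelength = 515e-9,
                           distance = 1e-2,
                           propagation_type = 'Angular Spectrum'
                          )
result = custom(field, H)                             # H can be cached and reused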

fraunhofer(field, k, distance, dx, wavelength)

A definition to calculate light transport using the Fraunhofer approximation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def fraunhofer(
               field,
               k,
               distance,
               dx,
               wavelength
              ):
    """
    A definition to calculate light transport using the Fraunhofer approximation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).
    """
    nv, nu = field.shape[-1], field.shape[-2]
    x = torch.linspace(-nv*dx/2, nv*dx/2, nv, dtype=torch.float32)
    y = torch.linspace(-nu*dx/2, nu*dx/2, nu, dtype=torch.float32)
    Y, X = torch.meshgrid(y, x, indexing='ij')
    Z = torch.pow(X, 2) + torch.pow(Y, 2)
    c = 1. / (1j * wavelength * distance) * torch.exp(1j * k * 0.5 / distance * Z)
    c = c.to(field.device)
    result = c * torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(field))) * dx ** 2
    return result
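
A sketch assuming the import path below; the standard far-field condition, distance much larger than 2 D^2 / wavelength for an aperture of width D, indicates where the Fraunhofer approximation is reasonable.

import torch
from odak.learn.wave import fraunhofer, wavenumber

wavelength = 515e-9
dx = 8e-6
n = 512
aperture_size = n * dx                                # roughly 4.1 mm
far_field_bound = 2. * aperture_size ** 2 / wavelength   # roughly 65 m
k = wavenumber(wavelength)
field = torch.ones(n, n, dtype = torch.complex64)
distance = 100.                                       # comfortably beyond the bound
result = fraunhofer(field, k, distance, dx, wavelength)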

gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel')

Definition to compute a hologram using an iterative method called the Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

Parameters:

  • field
               Complex field (MxN).
    
  • n_iterations
               Number of iterations.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • slm_range
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Type of the propagation (see odak.learn.wave.propagate_beam).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

  • reconstruction ( cfloat ) –

    Calculated reconstruction using calculated hologram.

Source code in odak/learn/wave/classical.py
def gerchberg_saxton(
                     field,
                     n_iterations,
                     distance,
                     dx,
                     wavelength,
                     slm_range = 6.28,
                     propagation_type = 'Transfer Function Fresnel'
                    ):
    """
    Definition to compute a hologram using an iterative method called the Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

    Parameters
    ----------
    field            : torch.cfloat
                       Complex field (MxN).
    n_iterations     : int
                       Number of iterations.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    slm_range        : float
                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    propagation_type : str
                       Type of the propagation (see odak.learn.wave.propagate_beam).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    reconstruction   : torch.cfloat
                       Calculated reconstruction using calculated hologram. 
    """
    k = wavenumber(wavelength)
    reconstruction = field
    for i in range(n_iterations):
        hologram = propagate_beam(
            reconstruction, k, -distance, dx, wavelength, propagation_type)
        reconstruction = propagate_beam(
            hologram, k, distance, dx, wavelength, propagation_type)
        reconstruction = set_amplitude(reconstruction, field)
    reconstruction = propagate_beam(
        hologram, k, distance, dx, wavelength, propagation_type)
    return hologram, reconstruction
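
A sketch of a phase-retrieval run on a random target amplitude, assuming the import path below and the default `Transfer Function Fresnel` propagation.

import torch
from odak.learn.wave import gerchberg_saxton

target = torch.rand(512, 512)                         # target amplitude in [0, 1]
field = target.to(torch.complex64)                    # zero-phase complex target
hologram, reconstruction = gerchberg_saxton(
                                            field,
                                            n_iterations = 10,
                                            distance = 1e-2,
                                            dx = 8e-6,
                                            wavelength = 515e-9
                                           )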

get_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_angular_spectrum_kernel(
                                nu,
                                nv,
                                dx = 8e-6,
                                wavelength = 515e-9,
                                distance = 0.,
                                device = torch.device('cpu')
                               ):
    """
    Helper function for odak.learn.wave.angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    H = torch.exp(1j  * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
    H = H.to(device)
    return H
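
A quick inspection sketch, assuming the name is exported by odak.learn.wave; for this pitch and wavelength every sampled frequency propagates, so the kernel is phase-only.

import torch
from odak.learn.wave import get_angular_spectrum_kernel

H = get_angular_spectrum_kernel(nu = 512, nv = 512, dx = 8e-6, wavelength = 515e-9, distance = 1e-2)
print(H.shape, H.dtype)                               # torch.Size([512, 512]) torch.complex64
print(H.abs().mean())                                 # close to 1: a phase-only kernel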

get_band_limited_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.band_limited_angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_band_limited_angular_spectrum_kernel(
                                             nu,
                                             nv,
                                             dx = 8e-6,
                                             wavelength = 515e-9,
                                             distance = 0.,
                                             device = torch.device('cpu')
                                            ):
    """
    Helper function for odak.learn.wave.band_limited_angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    """
    x = dx * float(nu)
    y = dx * float(nv)
    fx = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * x),
                         1 / (2 * dx) - 0.5 / (2 * x),
                         nu,
                         dtype = torch.float32,
                         device = device
                        )
    fy = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * y),
                        1 / (2 * dx) - 0.5 / (2 * y),
                        nv,
                        dtype = torch.float32,
                        device = device
                       )
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    HH_exp = 2 * torch.pi * torch.sqrt(1 / wavelength ** 2 - (FX ** 2 + FY ** 2))
    distance = torch.tensor([distance], device = device)
    H_exp = torch.mul(HH_exp, distance)
    fx_max = 1 / torch.sqrt((2 * distance * (1 / x))**2 + 1) / wavelength
    fy_max = 1 / torch.sqrt((2 * distance * (1 / y))**2 + 1) / wavelength
    H_filter = ((torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).clone().detach()
    H = generate_complex_field(H_filter, H_exp)
    return H
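
A sketch that estimates how much of the frequency plane survives the band limit at a given distance, assuming the import path below.

import torch
from odak.learn.wave import get_band_limited_angular_spectrum_kernel

H = get_band_limited_angular_spectrum_kernel(nu = 512, nv = 512, dx = 8e-6, wavelength = 515e-9, distance = 1e-1)
print((H.abs() > 0.5).float().mean())                 # fraction of frequencies kept by the band limit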

get_impulse_response_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), scale=1, aperture_samples=[20, 20, 5, 5])

Helper function for odak.learn.wave.impulse_response_fresnel.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • scale
                 Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    
  • aperture_samples
                  Number of samples to represent a rectangular pixel. The first two are for the XY axes of hologram plane pixels, and the last two are for image plane pixels.
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_impulse_response_fresnel_kernel(
                                        nu,
                                        nv,
                                        dx = 8e-6,
                                        wavelength = 515e-9,
                                        distance = 0.,
                                        device = torch.device('cpu'),
                                        scale = 1,
                                        aperture_samples = [20, 20, 5, 5]
                                       ):
    """
    Helper function for odak.learn.wave.impulse_response_fresnel.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    scale              : int
                         Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    aperture_samples   : list
                         Number of samples to represent a rectangular pixel. The first two are for the XY axes of hologram plane pixels, and the last two are for image plane pixels.

    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    """
    k = wavenumber(wavelength)
    distance = torch.as_tensor(distance, device = device)
    length_x, length_y = (torch.tensor(dx * nu, device = device), torch.tensor(dx * nv, device = device))
    x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)
    y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)
    X, Y = torch.meshgrid(x, y, indexing = 'ij')
    wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device)
    wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device)
    h = torch.zeros(nu * scale, nv * scale, dtype = torch.complex64, device = device)
    pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device)
    pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device)
    for wx in tqdm(wxs):
        for wy in wys:
            for px in pxs:
                for py in pys:
                    r = (X + px - wx) ** 2 + (Y + py - wy) ** 2
                    h += 1. / (1j * wavelength * distance) * torch.exp(1j * k / (2 * distance) * r) 
    H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]
    return H
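
A sketch kept deliberately small, assuming the import path below; the nested sampling loops over `aperture_samples` make this kernel expensive to build at full resolution.

import torch
from odak.learn.wave import get_impulse_response_fresnel_kernel

H = get_impulse_response_fresnel_kernel(
                                        nu = 64,
                                        nv = 64,
                                        dx = 8e-6,
                                        wavelength = 515e-9,
                                        distance = 1e-3,
                                        aperture_samples = [4, 4, 2, 2]   # small counts for speed
                                       )
print(H.shape)                                        # torch.Size([64, 64]) when scale = 1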

get_incoherent_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.incoherent_angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_incoherent_angular_spectrum_kernel(
                                           nu,
                                           nv,
                                           dx = 8e-6,
                                           wavelength = 515e-9,
                                           distance = 0.,
                                           device = torch.device('cpu')
                                          ):
    """
    Helper function for odak.learn.wave.incoherent_angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    H = torch.exp(1j  * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
    H_ptime = correlation_2d(H, H)
    H = H_ptime.to(device)
    return H
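
A brief sketch, assuming the name is exported by odak.learn.wave; as the source above shows, this kernel is the coherent angular spectrum kernel correlated with itself.

import torch
from odak.learn.wave import get_incoherent_angular_spectrum_kernel

H = get_incoherent_angular_spectrum_kernel(nu = 256, nv = 256, dx = 8e-6, wavelength = 515e-9, distance = 1e-2)
print(H.shape)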

get_light_kernels(wavelengths, distances, pixel_pitches, resolution=[1080, 1920], resolution_factor=1, samples=[50, 50, 5, 5], propagation_type='Bandlimited Angular Spectrum', kernel_type='spatial', device=torch.device('cpu'))

Utility function to request a tensor filled with light transport kernels according to the given optical configurations.

Parameters:

  • wavelengths
                 A list of wavelengths.
    
  • distances
                 A list of propagation distances.
    
  • pixel_pitches
                 A list of pixel_pitches.
    
  • resolution
                 Resolution of the light transport kernel.
    
  • resolution_factor
                  If `Impulse Response Fresnel` propagation is used, this resolution factor could be set larger than one, leading to higher resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().
    
  • samples
                  If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().
    
  • propagation_type
                 Propagation type. For more, see odak.learn.wave.propagate_beam().
    
  • kernel_type
                 If set to `spatial`, light transport kernels will be provided in space. But if set to `fourier`, these kernels will be provided in the Fourier domain.
    
  • device
                 Device used for computation (i.e., cpu, cuda).
    

Returns:

  • light_kernels_amplitude ( tensor ) –

    Amplitudes of the light kernels generated [w x d x p x m x n].

  • light_kernels_phase ( tensor ) –

    Phases of the light kernels generated [w x d x p x m x n].

  • light_kernels_complex ( tensor ) –

    Complex light kernels generated [w x d x p x m x n].

  • light_parameters ( tensor ) –

    Parameters of each pixel in light_kernels* [w x d x p x m x n x 5]. The last dimension contains wavelengths, distances, pixel pitches, and X and Y locations, in that order.

Source code in odak/learn/wave/classical.py
def get_light_kernels(
                      wavelengths,
                      distances,
                      pixel_pitches,
                      resolution = [1080, 1920],
                      resolution_factor = 1,
                      samples = [50, 50, 5, 5],
                      propagation_type = 'Bandlimited Angular Spectrum',
                      kernel_type = 'spatial',
                      device = torch.device('cpu')
                     ):
    """
    Utility function to request a tensor filled with light transport kernels according to the given optical configurations.

    Parameters
    ----------
    wavelengths        : list
                         A list of wavelengths.
    distances          : list
                         A list of propagation distances.
    pixel_pitches      : list
                         A list of pixel_pitches.
    resolution         : list
                         Resolution of the light transport kernel.
    resolution_factor  : int
                         If `Impulse Response Fresnel` propagation is used, this resolution factor could be set larger than one, leading to higher resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().
    samples            : list
                         If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().
    propagation_type   : str
                         Propagation type. For more, see odak.learn.wave.propagate_beam().
    kernel_type        : str
                         If set to `spatial`, light transport kernels will be provided in space. But if set to `fourier`, these kernels will be provided in the Fourier domain.
    device             : torch.device
                         Device used for computation (i.e., cpu, cuda).

    Returns
    -------
    light_kernels_amplitude : torch.tensor
                              Amplitudes of the light kernels generated [w x d x p x m x n].
    light_kernels_phase     : torch.tensor
                              Phases of the light kernels generated [w x d x p x m x n].
    light_kernels_complex   : torch.tensor
                              Complex light kernels generated [w x d x p x m x n].
    light_parameters        : torch.tensor
                              Parameters of each pixel in light_kernels* [w x d x p x m x n x 5]. The last dimension contains wavelengths, distances, pixel pitches, and X and Y locations, in that order.
    """
    if propagation_type != 'Impulse Response Fresnel':
        resolution_factor = 1
    light_kernels_complex = torch.zeros(            
                                        len(wavelengths),
                                        len(distances),
                                        len(pixel_pitches),
                                        resolution[0] * resolution_factor,
                                        resolution[1] * resolution_factor,
                                        dtype = torch.complex64,
                                        device = device
                                       )
    light_parameters = torch.zeros(
                                   len(wavelengths),
                                   len(distances),
                                   len(pixel_pitches),
                                   resolution[0] * resolution_factor,
                                   resolution[1] * resolution_factor,
                                   5,
                                   dtype = torch.float32,
                                   device = device
                                  )
    for wavelength_id, distance_id, pixel_pitch_id in itertools.product(
                                                                        range(len(wavelengths)),
                                                                        range(len(distances)),
                                                                        range(len(pixel_pitches)),
                                                                       ):
        pixel_pitch = pixel_pitches[pixel_pitch_id]
        wavelength = wavelengths[wavelength_id]
        distance = distances[distance_id]
        kernel_fourier = get_propagation_kernel(
                                                nu = resolution[0],
                                                nv = resolution[1],
                                                dx = pixel_pitch,
                                                wavelength = wavelength,
                                                distance = distance,
                                                device = device,
                                                propagation_type = propagation_type,
                                                scale = resolution_factor,
                                                samples = samples
                                               )
        if kernel_type == 'spatial':
            kernel = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(kernel_fourier)))
        elif kernel_type == 'fourier':
            kernel = kernel_fourier
        else:
            logging.warning('Unknown kernel type requested.')
            raise ValueError('Unknown kernel type requested.')
        kernel_amplitude = calculate_amplitude(kernel)
        kernel_phase = calculate_phase(kernel) % (2 * torch.pi)
        light_kernels_complex[wavelength_id, distance_id, pixel_pitch_id] = kernel
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 0] = wavelength
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 1] = distance
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 2] = pixel_pitch
        x = torch.linspace(-1., 1., resolution[0] * resolution_factor, device = device) * pixel_pitch / 2. * resolution[0]
        y = torch.linspace(-1., 1., resolution[1] * resolution_factor, device = device) * pixel_pitch / 2. * resolution[1]
        X, Y = torch.meshgrid(x, y, indexing = 'ij')
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 3] = X
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 4] = Y
    light_kernels_amplitude = calculate_amplitude(light_kernels_complex)
    light_kernels_phase = calculate_phase(light_kernels_complex) % (2. * torch.pi)
    return light_kernels_amplitude, light_kernels_phase, light_kernels_complex, light_parameters
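
A sketch requesting kernels for three hypothetical wavelengths, two distances and one pixel pitch, assuming the import path below.

import torch
from odak.learn.wave import get_light_kernels

amplitudes, phases, kernels, parameters = get_light_kernels(
                                                            wavelengths = [473e-9, 515e-9, 639e-9],
                                                            distances = [1e-3, 2e-3],
                                                            pixel_pitches = [8e-6],
                                                            resolution = [256, 256],
                                                            kernel_type = 'spatial'
                                                           )
print(kernels.shape)                                  # torch.Size([3, 2, 1, 256, 256])
print(parameters.shape)                               # torch.Size([3, 2, 1, 256, 256, 5])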

get_point_wise_impulse_response_fresnel_kernel(aperture_points, aperture_field, target_points, resolution, resolution_factor=1, wavelength=5.15e-07, distance=0.0, randomization=False, device=torch.device('cpu'))

This function is a freeform point spread function calculation routine for an aperture defined with a complex field, aperture_field, and locations in space, aperture_points. The point spread function is calculated over provided points, target_points. The final result is reshaped to follow the provided resolution.

Parameters:

  • aperture_points
                       Points representing an aperture in Euler space (XYZ) [m x 3].
    
  • aperture_field
                       Complex field for each point provided by `aperture_points` [1 x m].
    
  • target_points
                       Target points where the propagated field will be calculated [n x 1].
    
  • resolution
                       Final resolution that the propagated field will be reshaped [X x Y].
    
  • resolution_factor
                        Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field).
    
  • wavelength
                       Wavelength in meters.
    
  • randomization
                        If set to `True`, this will generate a noisy response that roughly approximates a real-life case with imperfections.
    
  • distance
                       Distance in meters.
    

Returns:

  • h ( complex64 ) –

    Complex field in spatial domain.

Source code in odak/learn/wave/classical.py
def get_point_wise_impulse_response_fresnel_kernel(
                                                   aperture_points,
                                                   aperture_field,
                                                   target_points,
                                                   resolution,
                                                   resolution_factor = 1,
                                                   wavelength = 515e-9,
                                                   distance = 0.,
                                                   randomization = False,
                                                   device = torch.device('cpu')
                                                  ):
    """
    This function is a freeform point spread function calculation routine for an aperture defined with a complex field, `aperture_field`, and locations in space, `aperture_points`.
    The point spread function is calculated over provided points, `target_points`.
    The final result is reshaped to follow the provided `resolution`.

    Parameters
    ----------
    aperture_points          : torch.tensor
                               Points representing an aperture in Euler space (XYZ) [m x 3].
    aperture_field           : torch.tensor
                               Complex field for each point provided by `aperture_points` [1 x m].
    target_points            : torch.tensor
                               Target points where the propagated field will be calculated [n x 1].
    resolution               : list
                               Final resolution that the propagated field will be reshaped [X x Y].
    resolution_factor        : int
                               Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field).
    wavelength               : float
                               Wavelength in meters.
    randomization            : bool
                               If set to `True`, this will generate a noisy response that roughly approximates a real-life case with imperfections.
    distance                 : float
                               Distance in meters.

    Returns
    -------
    h                        : torch.complex64
                               Complex field in spatial domain.
    """
    device = aperture_field.device
    k = wavenumber(wavelength)
    if randomization:
        pp = [
              aperture_points[:, 0].max() - aperture_points[:, 0].min(),
              aperture_points[:, 1].max() - aperture_points[:, 1].min()
             ]
        target_points[:, 0] = target_points[:, 0] - torch.randn(target_points[:, 0].shape) * pp[0]
        target_points[:, 1] = target_points[:, 1] - torch.randn(target_points[:, 1].shape) * pp[1]
    deltaX = aperture_points[:, 0].unsqueeze(0) - target_points[:, 0].unsqueeze(-1)
    deltaY = aperture_points[:, 1].unsqueeze(0) - target_points[:, 1].unsqueeze(-1)
    r = deltaX ** 2 + deltaY ** 2
    h = torch.exp(1j * k / (2 * distance) * r) * aperture_field
    h = torch.sum(h, dim = 1).reshape(resolution[0] * resolution_factor, resolution[1] * resolution_factor)
    h = 1. / (1j * wavelength * distance) * h
    return h
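
A sketch with a hypothetical square aperture sampled as a grid of points at z = 0 and target points matching the requested output resolution; the import path and all values below are assumptions.

import torch
from odak.learn.wave import get_point_wise_impulse_response_fresnel_kernel

resolution = [64, 64]
dx = 8e-6
wavelength = 515e-9
distance = 5e-3
# Hypothetical aperture: a 33 x 33 grid of point sources at z = 0.
side = torch.linspace(-16., 16., 33) * dx
AX, AY = torch.meshgrid(side, side, indexing = 'ij')
aperture_points = torch.stack((AX.flatten(), AY.flatten(), torch.zeros(33 * 33)), dim = -1)
aperture_field = torch.ones(1, aperture_points.shape[0], dtype = torch.complex64)
# Target points: one per output pixel, placed at the propagation distance.
tx = torch.linspace(-32., 31., resolution[0]) * dx
ty = torch.linspace(-32., 31., resolution[1]) * dx
TX, TY = torch.meshgrid(tx, ty, indexing = 'ij')
target_points = torch.stack((TX.flatten(), TY.flatten(), torch.full((64 * 64,), distance)), dim = -1)
h = get_point_wise_impulse_response_fresnel_kernel(
                                                   aperture_points,
                                                   aperture_field,
                                                   target_points,
                                                   resolution,
                                                   wavelength = wavelength,
                                                   distance = distance
                                                  )
print(h.shape)                                        # torch.Size([64, 64])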

get_propagation_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), propagation_type='Bandlimited Angular Spectrum', scale=1, samples=[20, 20, 5, 5])

Get propagation kernel for the propagation type.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • propagation_type
                 Propagation type.
                  The options are `Angular Spectrum`, `Bandlimited Angular Spectrum`, `Transfer Function Fresnel`, `Impulse Response Fresnel`, `Seperable Impulse Response Fresnel`, and `Incoherent Angular Spectrum`.
    
  • scale
                 Scale factor for scaled beam propagation.
    
  • samples
                  When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
    

Returns:

  • kernel ( tensor ) –

    Complex kernel for the given propagation type.

Source code in odak/learn/wave/classical.py
def get_propagation_kernel(
                           nu, 
                           nv, 
                           dx = 8e-6, 
                           wavelength = 515e-9, 
                           distance = 0., 
                           device = torch.device('cpu'), 
                           propagation_type = 'Bandlimited Angular Spectrum', 
                           scale = 1,
                           samples = [20, 20, 5, 5]
                          ):
    """
    Get propagation kernel for the propagation type.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    propagation_type   : str
                         Propagation type.
                         The options are `Angular Spectrum`, `Bandlimited Angular Spectrum`, `Transfer Function Fresnel`, `Impulse Response Fresnel`, `Seperable Impulse Response Fresnel`, and `Incoherent Angular Spectrum`.
    scale              : int
                         Scale factor for scaled beam propagation.
    samples            : list
                         When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.


    Returns
    -------
    kernel             : torch.tensor
                         Complex kernel for the given propagation type.
    """                                                      
    logging.warning(
                    'Requested propagation kernel for %s method with %s m distance, %s m pixel pitch, %s m wavelength, %s x %s resolution, x%s scale and %s samples.',
                    propagation_type, distance, dx, wavelength, nu, nv, scale, samples
                   )
    if propagation_type == 'Bandlimited Angular Spectrum':
        kernel = get_band_limited_angular_spectrum_kernel(
                                                          nu = nu,
                                                          nv = nv,
                                                          dx = dx,
                                                          wavelength = wavelength,
                                                          distance = distance,
                                                          device = device
                                                         )
    elif propagation_type == 'Angular Spectrum':
        kernel = get_angular_spectrum_kernel(
                                             nu = nu,
                                             nv = nv,
                                             dx = dx,
                                             wavelength = wavelength,
                                             distance = distance,
                                             device = device
                                            )
    elif propagation_type == 'Transfer Function Fresnel':
        kernel = get_transfer_function_fresnel_kernel(
                                                      nu = nu,
                                                      nv = nv,
                                                      dx = dx,
                                                      wavelength = wavelength,
                                                      distance = distance,
                                                      device = device
                                                     )
    elif propagation_type == 'Impulse Response Fresnel':
        kernel = get_impulse_response_fresnel_kernel(
                                                     nu = nu, 
                                                     nv = nv, 
                                                     dx = dx, 
                                                     wavelength = wavelength,
                                                     distance = distance,
                                                     device =  device,
                                                     scale = scale,
                                                     aperture_samples = samples
                                                    )
    elif propagation_type == 'Incoherent Angular Spectrum':
        kernel = get_incoherent_angular_spectrum_kernel(
                                                        nu = nu,
                                                        nv = nv, 
                                                        dx = dx, 
                                                        wavelength = wavelength, 
                                                        distance = distance,
                                                        device = device
                                                       )
    elif propagation_type == 'Seperable Impulse Response Fresnel':
        kernel, _, _, _ = get_seperable_impulse_response_fresnel_kernel(
                                                                        nu = nu,
                                                                        nv = nv,
                                                                        dx = dx,
                                                                        wavelength = wavelength,
                                                                        distance = distance,
                                                                        device = device,
                                                                        scale = scale,
                                                                        aperture_samples = samples
                                                                       )
    else:
        logging.warning('Propagation type not recognized.')
        raise ValueError('Propagation type not recognized.')
    return kernel
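
A sketch iterating over a few of the supported propagation types, assuming the import path below.

import torch
from odak.learn.wave import get_propagation_kernel

for propagation_type in ['Angular Spectrum', 'Bandlimited Angular Spectrum', 'Transfer Function Fresnel']:
    H = get_propagation_kernel(
                               nu = 256,
                               nv = 256,
                               dx = 8e-6,
                               wavelength = 515e-9,
                               distance = 1e-2,
                               propagation_type = propagation_type
                              )
    print(propagation_type, H.shape, H.dtype)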

get_seperable_impulse_response_fresnel_kernel(nu, nv, dx=3.74e-06, wavelength=5.15e-07, distance=0.0, scale=1, aperture_samples=[50, 50, 5, 5], device=torch.device('cpu'))

Returns the impulse response Fresnel kernel in separable form.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • scale
                 Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    
  • aperture_samples
                  Number of samples to represent a rectangular pixel. The first two are for the XY axes of hologram plane pixels, and the last two are for image plane pixels.
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

  • h ( complex64 ) –

    Complex kernel in spatial domain.

  • h_x ( complex64 ) –

    1D complex kernel in spatial domain along X axis.

  • h_y ( complex64 ) –

    1D complex kernel in spatial domain along Y axis.

Source code in odak/learn/wave/classical.py
def get_seperable_impulse_response_fresnel_kernel(
                                                  nu,
                                                  nv,
                                                  dx = 3.74e-6,
                                                  wavelength = 515e-9,
                                                  distance = 0.,
                                                  scale = 1,
                                                  aperture_samples = [50, 50, 5, 5],
                                                  device = torch.device('cpu')
                                                 ):
    """
    Returns the impulse response Fresnel kernel in separable form.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    scale              : int
                         Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    aperture_samples   : list
                         Number of samples to represent a rectangular pixel. The first two are for the XY axes of hologram plane pixels, and the last two are for image plane pixels.

    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    h                  : torch.complex64
                         Complex kernel in spatial domain.
    h_x                : torch.complex64
                         1D complex kernel in spatial domain along X axis.
    h_y                : torch.complex64
                         1D complex kernel in spatial domain along Y axis.
    """
    k = wavenumber(wavelength)
    distance = torch.as_tensor(distance, device = device)
    length_x, length_y = (
                          torch.tensor(dx * nu, device = device),
                          torch.tensor(dx * nv, device = device)
                         )
    x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)
    y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)
    wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device).unsqueeze(0).unsqueeze(0)
    wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device).unsqueeze(0).unsqueeze(-1)
    pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device).unsqueeze(0).unsqueeze(-1)
    pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device).unsqueeze(0).unsqueeze(0)
    wxs = (wxs - pxs).reshape(1, -1).unsqueeze(-1)
    wys = (wys - pys).reshape(1, -1).unsqueeze(1)

    X = x.unsqueeze(-1).unsqueeze(-1)
    Y = y[y.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)
    r_x = (X + wxs) ** 2
    r_y = (Y + wys) ** 2
    r = r_x + r_y
    h_x = torch.exp(1j * k / (2 * distance) * r)
    h_x = torch.sum(h_x, axis = (1, 2))

    if nu != nv:
        X = x[x.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)
        Y = y.unsqueeze(-1).unsqueeze(-1)
        r_x = (X + wxs) ** 2
        r_y = (Y + wys) ** 2
        r = r_x + r_y
        h_y = torch.exp(1j * k * r / (2 * distance))
        h_y = torch.sum(h_y, axis = (1, 2))
    else:
        h_y = h_x.detach().clone()
    h = torch.exp(1j * k * distance) / (1j * wavelength * distance) * h_x.unsqueeze(1) * h_y.unsqueeze(0)
    H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]
    return H, h, h_x, h_y
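
A sketch showing the separable outputs, assuming the import path and the reduced sample counts below; storing the 1D kernels `h_x` and `h_y` is cheaper than the full 2D kernel, since `h` is assembled from their outer product.

import torch
from odak.learn.wave import get_seperable_impulse_response_fresnel_kernel

H, h, h_x, h_y = get_seperable_impulse_response_fresnel_kernel(
                                                               nu = 128,
                                                               nv = 128,
                                                               dx = 3.74e-6,
                                                               wavelength = 515e-9,
                                                               distance = 1e-3,
                                                               aperture_samples = [10, 10, 2, 2]
                                                              )
print(h_x.shape, h_y.shape, h.shape, H.shape)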

get_transfer_function_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.transfer_function_fresnel.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_transfer_function_fresnel_kernel(
                                         nu,
                                         nv,
                                         dx = 8e-6,
                                         wavelength = 515e-9,
                                         distance = 0.,
                                         device = torch.device('cpu')
                                        ):
    """
    Helper function for odak.learn.wave.transfer_function_fresnel.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing = 'ij')
    k = wavenumber(wavelength)
    H = torch.exp(-1j * distance * (k - torch.pi * wavelength * (FX ** 2 + FY ** 2)))
    return H
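
A quick sketch, assuming the import path below; the kernel is a unit-magnitude quadratic phase in the Fourier domain.

import torch
from odak.learn.wave import get_transfer_function_fresnel_kernel

H = get_transfer_function_fresnel_kernel(nu = 512, nv = 512, dx = 8e-6, wavelength = 515e-9, distance = 1e-3)
print(H.shape, H.abs().mean())                        # unit-magnitude quadratic phase kernel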

impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0, scale=1, samples=[20, 20, 5, 5])

A definition to calculate convolution-based Fresnel approximation for beam propagation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    
  • scale
               Resolution factor to scale generated kernel.
    
  • samples
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def impulse_response_fresnel(
                             field,
                             k,
                             distance,
                             dx,
                             wavelength,
                             zero_padding = False,
                             aperture = 1.,
                             scale = 1,
                             samples = [20, 20, 5, 5]
                            ):
    """
    A definition to calculate convolution-based Fresnel approximation for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].
    scale            : int
                       Resolution factor to scale generated kernel.
    samples          : list
                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Impulse Response Fresnel',
                               device = field.device,
                               scale = scale,
                               samples = samples
                              )
    if scale > 1:
        field_amplitude = calculate_amplitude(field)
        field_phase = calculate_phase(field)
        field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)
        field_scale_phase = torch.zeros_like(field_scale_amplitude)
        field_scale_amplitude[::scale, ::scale] = field_amplitude
        field_scale_phase[::scale, ::scale] = field_phase
        field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)
    else:
        field_scale = field
    result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)
    return result
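
A minimal usage sketch follows; the wavelength, pixel pitch, distance, and field size are illustrative assumptions, not values prescribed by the library, and the names are assumed importable from odak.learn.wave as documented on this page.

import torch
from odak.learn.wave import impulse_response_fresnel

wavelength = 515e-9                                      # wavelength in meters
dx = 8e-6                                                # pixel pitch in meters
distance = 5e-3                                          # propagation distance in meters
k = 2 * torch.pi / wavelength                            # wave number
field = torch.ones(256, 256, dtype = torch.complex64)    # unit amplitude field
result = impulse_response_fresnel(field, k, distance, dx, wavelength, scale = 2)
# With scale = 2, the input is spread onto a twice-denser grid before the
# convolution, so the result here is expected to be a [512 x 512] complex field.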

incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate incoherent beam propagation with Angular Spectrum method.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def incoherent_angular_spectrum(
                                field,
                                k,
                                distance,
                                dx,
                                wavelength,
                                zero_padding = False,
                                aperture = 1.
                               ):
    """
    A definition to calculate incoherent beam propagation with Angular Spectrum method.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Incoherent Angular Spectrum',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
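
A minimal usage sketch, mirroring the coherent variants; the parameter values below are illustrative assumptions.

import torch
from odak.learn.wave import incoherent_angular_spectrum

wavelength = 515e-9                                      # wavelength in meters
dx = 8e-6                                                # pixel pitch in meters
distance = 1e-2                                          # propagation distance in meters
k = 2 * torch.pi / wavelength                            # wave number
field = torch.randn(256, 256, dtype = torch.complex64)   # arbitrary complex field
result = incoherent_angular_spectrum(field, k, distance, dx, wavelength)
# result keeps the [256 x 256] resolution of the input field.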

point_wise(target, wavelength, distance, dx, device, lens_size=401)

Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

Parameters:

  • target
               Float input target to be converted into a hologram (target values should be in the range [0, 1]).
    
  • wavelength
               Wavelength of the electric field.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • device
               Device type (cuda or cpu).
    
  • lens_size
               Size of the lens for masking sub-holograms (in pixels).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

Source code in odak/learn/wave/classical.py
def point_wise(
               target,
               wavelength,
               distance,
               dx,
               device,
               lens_size=401
              ):
    """
    Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

    Parameters
    ----------
    target           : torch.float
                       Float input target to be converted into a hologram (target values should be in the range [0, 1]).
    wavelength       : float
                       Wavelength of the electric field.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    device           : torch.device
                       Device type (cuda or cpu).
    lens_size        : int
                       Size of the lens for masking sub-holograms (in pixels).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    """
    target = zero_pad(target)
    nx, ny = target.shape
    k = wavenumber(wavelength)
    ones = torch.ones(target.shape, requires_grad=False).to(device)
    x = torch.linspace(-nx/2, nx/2, nx).to(device)
    y = torch.linspace(-ny/2, ny/2, ny).to(device)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = (X**2+Y**2)**0.5
    mask = (torch.abs(Z) <= lens_size)
    mask[mask > 1] = 1
    fz = quadratic_phase_function(nx, ny, k, focal=-distance, dx=dx).to(device)
    A = torch.nan_to_num(target**0.5, nan=0.0)
    fz = mask*fz
    FA = torch.fft.fft2(torch.fft.fftshift(A))
    FFZ = torch.fft.fft2(torch.fft.fftshift(fz))
    H = torch.mul(FA, FFZ)
    hologram = torch.fft.ifftshift(torch.fft.ifft2(H))
    hologram = crop_center(hologram)
    return hologram
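
A minimal sketch converting a normalized target image into a complex hologram; the resolution and optical parameters are illustrative assumptions.

import torch
from odak.learn.wave import point_wise

device = torch.device('cpu')
target = torch.rand(512, 512, device = device)           # target values in [0, 1]
wavelength = 515e-9                                      # wavelength in meters
distance = 15e-2                                         # propagation distance in meters
dx = 8e-6                                                # pixel pitch in meters
hologram = point_wise(target, wavelength, distance, dx, device, lens_size = 401)
# hologram is a complex field with the same resolution as the target,
# assembled from sub-holograms masked by the given lens size.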

propagate_beam(field, k, distance, dx, wavelength, propagation_type='Bandlimited Angular Spectrum', kernel=None, zero_padding=[True, False, True], aperture=1.0, scale=1, samples=[20, 20, 5, 5])

Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • propagation_type (str, default: 'Bandlimited Angular Spectrum' ) –
               Type of the propagation.
               The options are Impulse Response Fresnel, Seperable Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Incoherent Angular Spectrum, Fraunhofer, and custom (requires `kernel`).
    
  • kernel
               Custom complex kernel.
    
  • zero_padding
               Zero pads the input field if the first item in the list is set to True.
               Zero pads in the Fourier domain if the second item in the list is set to True.
               Crops the result to half resolution if the third item in the list is set to True.
               Note that in Fraunhofer propagation, the second item has no effect.
    
  • aperture
               Aperture in the Fourier domain; with the default zero padding it is [2m x 2n], otherwise its size depends on `zero_padding`.
               If provided as the floating-point value 1, no aperture is applied in the Fourier domain.
    
  • scale
               Resolution factor to scale the generated kernel.
    
  • samples
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram-plane pixel and the last two are for an image-plane pixel.
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def propagate_beam(
                   field,
                   k,
                   distance,
                   dx,
                   wavelength,
                   propagation_type='Bandlimited Angular Spectrum',
                   kernel = None,
                   zero_padding = [True, False, True],
                   aperture = 1.,
                   scale = 1,
                   samples = [20, 20, 5, 5]
                  ):
    """
    Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    propagation_type : str
                       Type of the propagation.
                       The options are Impulse Response Fresnel, Seperable Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Incoherent Angular Spectrum, Fraunhofer, and custom (requires `kernel`).
    kernel           : torch.complex
                       Custom complex kernel.
    zero_padding     : list
                       Zero pads the input field if the first item in the list is set to True.
                       Zero pads in the Fourier domain if the second item in the list is set to True.
                       Crops the result to half resolution if the third item in the list is set to True.
                       Note that in Fraunhofer propagation, the second item has no effect.
    aperture         : torch.tensor
                       Aperture in the Fourier domain; with the default zero padding it is [2m x 2n], otherwise its size depends on `zero_padding`.
                       If provided as the floating-point value 1, no aperture is applied in the Fourier domain.
    scale            : int
                       Resolution factor to scale the generated kernel.
    samples          : list
                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram-plane pixel and the last two are for an image-plane pixel.

    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    if zero_padding[0]:
        field = zero_pad(field)
    if propagation_type == 'Angular Spectrum':
        result = angular_spectrum(
                                  field = field,
                                  k = k,
                                  distance = distance,
                                  dx = dx,
                                  wavelength = wavelength,
                                  zero_padding = zero_padding[1],
                                  aperture = aperture
                                 )
    elif propagation_type == 'Bandlimited Angular Spectrum':
        result = band_limited_angular_spectrum(
                                               field = field,
                                               k = k,
                                               distance = distance,
                                               dx = dx,
                                               wavelength = wavelength,
                                               zero_padding = zero_padding[1],
                                               aperture = aperture
                                              )
    elif propagation_type == 'Impulse Response Fresnel':
        result = impulse_response_fresnel(
                                          field = field,
                                          k = k,
                                          distance = distance,
                                          dx = dx,
                                          wavelength = wavelength,
                                          zero_padding = zero_padding[1],
                                          aperture = aperture,
                                          scale = scale,
                                          samples = samples
                                         )
    elif propagation_type == 'Seperable Impulse Response Fresnel':
        result = seperable_impulse_response_fresnel(
                                                    field = field,
                                                    k = k,
                                                    distance = distance,
                                                    dx = dx,
                                                    wavelength = wavelength,
                                                    zero_padding = zero_padding[1],
                                                    aperture = aperture,
                                                    scale = scale,
                                                    samples = samples
                                                   )
    elif propagation_type == 'Transfer Function Fresnel':
        result = transfer_function_fresnel(
                                           field = field,
                                           k = k,
                                           distance = distance,
                                           dx = dx,
                                           wavelength = wavelength,
                                           zero_padding = zero_padding[1],
                                           aperture = aperture
                                          )
    elif propagation_type == 'custom':
        result = custom(
                        field = field,
                        kernel = kernel,
                        zero_padding = zero_padding[1],
                        aperture = aperture
                       )
    elif propagation_type == 'Fraunhofer':
        result = fraunhofer(
                            field = field,
                            k = k,
                            distance = distance,
                            dx = dx,
                            wavelength = wavelength
                           )
    elif propagation_type == 'Incoherent Angular Spectrum':
        result = incoherent_angular_spectrum(
                                             field = field,
                                             k = k,
                                             distance = distance,
                                             dx = dx,
                                             wavelength = wavelength,
                                             zero_padding = zero_padding[1],
                                             aperture = aperture
                                            )
    else:
        logging.warning('Propagation type not recognized')
        assert True == False
    if zero_padding[2]:
        result = crop_center(result)
    return result
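
A minimal end-to-end sketch of the dispatcher; the parameter values are illustrative assumptions.

import torch
from odak.learn.wave import propagate_beam

wavelength = 515e-9                                      # wavelength in meters
dx = 8e-6                                                # pixel pitch in meters
distance = 5e-3                                          # propagation distance in meters
k = 2 * torch.pi / wavelength                            # wave number
field = torch.ones(256, 256, dtype = torch.complex64)    # unit amplitude field
result = propagate_beam(
                        field,
                        k,
                        distance,
                        dx,
                        wavelength,
                        propagation_type = 'Bandlimited Angular Spectrum',
                        zero_padding = [True, False, True]
                       )
# The first flag pads the input to [512 x 512], the propagation runs on the
# padded field, and the third flag crops the result back to [256 x 256].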

seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0, scale=1, samples=[20, 20, 5, 5])

A definition to calculate convolution-based Fresnel approximation for beam propagation for a rectangular aperture using the separable property.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    
  • scale
               Resolution factor to scale the generated kernel.
    
  • samples
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram-plane pixel and the last two are for an image-plane pixel.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def seperable_impulse_response_fresnel(
                                       field,
                                       k,
                                       distance,
                                       dx,
                                       wavelength,
                                       zero_padding = False,
                                       aperture = 1.,
                                       scale = 1,
                                       samples = [20, 20, 5, 5]
                                      ):
    """
    A definition to calculate convolution-based Fresnel approximation for beam propagation for a rectangular aperture using the separable property.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].
    scale            : int
                       Resolution factor to scale the generated kernel.
    samples          : list
                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram-plane pixel and the last two are for an image-plane pixel.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Seperable Impulse Response Fresnel',
                               device = field.device,
                               scale = scale,
                               samples = samples
                              )
    if scale > 1:
        field_amplitude = calculate_amplitude(field)
        field_phase = calculate_phase(field)
        field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)
        field_scale_phase = torch.zeros_like(field_scale_amplitude)
        field_scale_amplitude[::scale, ::scale] = field_amplitude
        field_scale_phase[::scale, ::scale] = field_phase
        field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)
    else:
        field_scale = field
    result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)
    return result

shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None)

Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. Coded following the implementation at https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207 and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

Parameters:

  • phase
               Phase value of a phase-only hologram.
    
  • depth_shift
               Distance in meters.
    
  • pixel_pitch
               Pixel pitch size in meters.
    
  • wavelength
               Wavelength of light.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Beam propagation type. For more see odak.learn.wave.propagate_beam().
    
  • kernel_length
               Kernel length for the Gaussian blur kernel.
    
  • sigma
               Standard deviation for the Gaussian blur kernel.
    
  • amplitude
               Amplitude value of a complex hologram.
    
Source code in odak/learn/wave/classical.py
def shift_w_double_phase(
                         phase,
                         depth_shift,
                         pixel_pitch,
                         wavelength,
                         propagation_type = 'Transfer Function Fresnel',
                         kernel_length = 4,
                         sigma = 0.5,
                         amplitude = None
                        ):
    """
    Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. Coded following the implementation [here](https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207) and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

    Parameters
    ----------
    phase            : torch.tensor
                       Phase value of a phase-only hologram.
    depth_shift      : float
                       Distance in meters.
    pixel_pitch      : float
                       Pixel pitch size in meters.
    wavelength       : float
                       Wavelength of light.
    propagation_type : str
                       Beam propagation type. For more see odak.learn.wave.propagate_beam().
    kernel_length    : int
                       Kernel length for the Gaussian blur kernel.
    sigma            : float
                       Standard deviation for the Gaussian blur kernel.
    amplitude        : torch.tensor
                       Amplitude value of a complex hologram.
    """
    if amplitude is None:
        amplitude = torch.ones_like(phase)
    hologram = generate_complex_field(amplitude, phase)
    k = wavenumber(wavelength)
    hologram_padded = zero_pad(hologram)
    shifted_field_padded = propagate_beam(
                                          hologram_padded,
                                          k,
                                          depth_shift,
                                          pixel_pitch,
                                          wavelength,
                                          propagation_type
                                         )
    shifted_field = crop_center(shifted_field_padded)
    phase_shift = torch.exp(torch.tensor([-2 * torch.pi * depth_shift / wavelength]).to(phase.device))
    shift = torch.cos(phase_shift) + 1j * torch.sin(phase_shift)
    shifted_complex_hologram = shifted_field * shift

    if kernel_length > 0 and sigma > 0:
        blur_kernel = generate_2d_gaussian(
                                           [kernel_length, kernel_length],
                                           [sigma, sigma]
                                          ).to(phase.device)
        blur_kernel = blur_kernel.unsqueeze(0)
        blur_kernel = blur_kernel.unsqueeze(0)
        field_imag = torch.imag(shifted_complex_hologram)
        field_real = torch.real(shifted_complex_hologram)
        field_imag = field_imag.unsqueeze(0)
        field_imag = field_imag.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_imag = torch.nn.functional.conv2d(field_imag, blur_kernel, padding='same')
        field_real = torch.nn.functional.conv2d(field_real, blur_kernel, padding='same')
        shifted_complex_hologram = torch.complex(field_real, field_imag)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)

    shifted_amplitude = calculate_amplitude(shifted_complex_hologram)
    shifted_amplitude = shifted_amplitude / torch.amax(shifted_amplitude, [0,1])

    shifted_phase = calculate_phase(shifted_complex_hologram)
    phase_zero_mean = shifted_phase - torch.mean(shifted_phase)

    phase_offset = torch.arccos(shifted_amplitude)
    phase_low = phase_zero_mean - phase_offset
    phase_high = phase_zero_mean + phase_offset

    phase_only = torch.zeros_like(phase)
    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
    return phase_only
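
A minimal sketch shifting a phase-only hologram by a small depth offset; the values below are illustrative assumptions.

import torch
from odak.learn.wave import shift_w_double_phase

phase = torch.rand(512, 512) * 2 * torch.pi              # phase-only hologram
shifted_phase = shift_w_double_phase(
                                     phase,
                                     depth_shift = 1e-3,
                                     pixel_pitch = 8e-6,
                                     wavelength = 515e-9
                                    )
# shifted_phase is again phase only, encoding the shifted complex field
# through the checkerboard double phase arrangement shown above.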

stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type='Bandlimited Angular Spectrum', n_iteration=100, loss_function=None, learning_rate=0.1)

Definition to generate phase and reconstruction from a target image via stochastic gradient descent.

Parameters:

  • target
                        Target field amplitude [m x n].
                        Keep the target values between zero and one.
    
  • wavelength
                         Wavelength of the electric field.
    
  • distance
                         Distance between the SLM (hologram) plane and the image plane.
    
  • pixel_pitch
                        SLM pixel pitch in meters.
    
  • propagation_type
                        Type of the propagation (see odak.learn.wave.propagate_beam()).
    
  • n_iteration
                         Number of iterations.
    
  • loss_function
                         If None, it is set to an L2 (mean squared error) loss.
    
  • learning_rate
                        Learning rate.
    

Returns:

  • hologram ( Tensor ) –

    Phase-only hologram as a torch tensor.

  • reconstruction ( Tensor ) –

    Reconstructed complex field as a torch tensor.

Source code in odak/learn/wave/classical.py
def stochastic_gradient_descent(
                                target,
                                wavelength,
                                distance,
                                pixel_pitch,
                                propagation_type = 'Bandlimited Angular Spectrum',
                                n_iteration = 100,
                                loss_function = None,
                                learning_rate = 0.1
                               ):
    """
    Definition to generate phase and reconstruction from a target image via stochastic gradient descent.

    Parameters
    ----------
    target                    : torch.Tensor
                                Target field amplitude [m x n].
                                Keep the target values between zero and one.
    wavelength                : double
                                Wavelength of the electric field.
    distance                  : double
                                Distance between the SLM (hologram) plane and the image plane.
    pixel_pitch               : float
                                SLM pixel pitch in meters.
    propagation_type          : str
                                Type of the propagation (see odak.learn.wave.propagate_beam()).
    n_iteration               : int
                                Number of iterations.
    loss_function             : function
                                If None, it is set to an L2 (mean squared error) loss.
    learning_rate             : float
                                Learning rate.

    Returns
    -------
    hologram                  : torch.Tensor
                                Phase-only hologram as a torch tensor.

    reconstruction            : torch.Tensor
                                Reconstructed complex field as a torch tensor.

    """
    phase = torch.randn_like(target, requires_grad = True)
    k = wavenumber(wavelength)
    optimizer = torch.optim.Adam([phase], lr = learning_rate)
    if loss_function is None:
        loss_function = torch.nn.MSELoss()
    t = tqdm(range(n_iteration), leave = False, dynamic_ncols = True)
    for i in t:
        optimizer.zero_grad()
        hologram = generate_complex_field(1., phase)
        reconstruction = propagate_beam(
                                        hologram, 
                                        k, 
                                        distance, 
                                        pixel_pitch, 
                                        wavelength, 
                                        propagation_type, 
                                        zero_padding = [True, False, True]
                                       )
        reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
        loss = loss_function(reconstruction_intensity, target)
        description = "Loss:{:.4f}".format(loss.item())
        loss.backward(retain_graph = True)
        optimizer.step()
        t.set_description(description)
    logging.warning(description)
    with torch.no_grad():
        hologram = generate_complex_field(1., phase)
        reconstruction = propagate_beam(
                                        hologram,
                                        k,
                                        distance,
                                        pixel_pitch,
                                        wavelength,
                                        propagation_type,
                                        zero_padding = [True, False, True]
                                       )
    return hologram, reconstruction
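
A minimal optimization sketch; the target, pixel pitch, and iteration count are illustrative assumptions.

import torch
from odak.learn.wave import stochastic_gradient_descent

target = torch.rand(512, 512)                            # target intensities in [0, 1]
hologram, reconstruction = stochastic_gradient_descent(
                                                       target,
                                                       wavelength = 515e-9,
                                                       distance = 15e-2,
                                                       pixel_pitch = 8e-6,
                                                       n_iteration = 200,
                                                       learning_rate = 0.1
                                                      )
# hologram carries the optimized phase as a complex field and
# reconstruction is the complex field at the image plane.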

transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution-based Fresnel approximation for beam propagation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def transfer_function_fresnel(
                              field,
                              k,
                              distance,
                              dx,
                              wavelength,
                              zero_padding = False,
                              aperture = 1.
                             ):
    """
    A definition to calculate convolution-based Fresnel approximation for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Transfer Function Fresnel',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result

blazed_grating(nx, ny, levels=2, axis='x')

A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • levels
           Number of phase levels; also the grating period in pixels.
    
  • axis
           Axis of the blazed grating. It could be `x` or `y`.
    
Source code in odak/learn/wave/lens.py
def blazed_grating(nx, ny, levels = 2, axis = 'x'):
    """
    A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.


    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    levels       : int
                   Number of phase levels; also the grating period in pixels.
    axis         : str
                   Axis of the blazed grating. It could be `x` or `y`.

    """
    if levels < 2:
        levels = 2
    x = (torch.abs(torch.arange(-nx, 0)) % levels) / levels * (2 * np.pi)
    y = (torch.abs(torch.arange(-ny, 0)) % levels) / levels * (2 * np.pi)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'x':
        blazed_grating = torch.exp(1j * X)
    elif axis == 'y':
        blazed_grating = torch.exp(1j * Y)
    return blazed_grating
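
A minimal sketch; the grid size and level count are illustrative assumptions.

from odak.learn.wave import blazed_grating

grating = blazed_grating(nx = 512, ny = 512, levels = 8, axis = 'x')
# grating is a [512 x 512] complex field whose phase ramps from zero
# towards 2 pi in eight steps, repeating every eight pixels along X.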

linear_grating(nx, ny, every=2, add=None, axis='x')

A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings see the odak.learn.wave.blazed_grating() function.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • every
         Adds the `add` value at every `every` pixels along the chosen axis.
    
  • add
         Phase angle to be added; defaults to pi.
    
  • axis
         Axis; either `x`, `y`, or `xy` (both).
    

Returns:

  • field ( tensor ) –

    Linear grating term.

Source code in odak/learn/wave/lens.py
def linear_grating(nx, ny, every = 2, add = None, axis = 'x'):
    """
    A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings see the odak.learn.wave.blazed_grating() function.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    every      : int
                 Adds the `add` value at every `every` pixels along the chosen axis.
    add        : float
                 Phase angle to be added; defaults to pi.
    axis       : string
                 Axis; either `x`, `y`, or `xy` (both).

    Returns
    -------
    field      : torch.tensor
                 Linear grating term.
    """
    if add is None:
        add = np.pi
    grating = torch.zeros((nx, ny), dtype=torch.complex64)
    if axis == 'x':
        grating[::every, :] = torch.exp(torch.tensor(1j*add))
    if axis == 'y':
        grating[:, ::every] = torch.exp(torch.tensor(1j*add))
    if axis == 'xy':
        checker = np.indices((nx, ny)).sum(axis=0) % every
        checker = torch.from_numpy(checker)
        checker += 1
        checker = checker % 2
        grating = torch.exp(1j*checker*add)
    return grating
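
A minimal sketch producing a checkerboard phase pattern; the sizes are illustrative assumptions.

from odak.learn.wave import linear_grating

grating = linear_grating(nx = 512, ny = 512, every = 2, axis = 'xy')
# With axis = 'xy' and every = 2, the result alternates between zero and
# pi phase in a checkerboard, i.e. a binary grating along both axes.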

prism_grating(nx, ny, k, angle, dx=0.001, axis='x', phase_offset=0.0)

A definition to generate a 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics express 16.22 (2008): 18275-18287 for more.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • k
           See odak.wave.wavenumber for more.
    
  • angle
           Tilt angle of the prism in degrees.
    
  • dx
           Pixel pitch.
    
  • axis
           Axis of the prism.
    
  • phase_offset (float, default: 0.0 ) –
           Phase offset in degrees. Default is zero.
    

Returns:

  • prism ( tensor ) –

    Generated phase function for a prism.

Source code in odak/learn/wave/lens.py
def prism_grating(nx, ny, k, angle, dx = 0.001, axis = 'x', phase_offset = 0.):
    """
    A definition to generate a 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics express 16.22 (2008): 18275-18287 for more.

    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    k            : odak.wave.wavenumber
                   See odak.wave.wavenumber for more.
    angle        : float
                   Tilt angle of the prism in degrees.
    dx           : float
                   Pixel pitch.
    axis         : str
                   Axis of the prism.
    phase_offset : float
                   Phase offset in degrees. Default is zero.

    Returns
    -------
    prism        : torch.tensor
                   Generated phase function for a prism.
    """
    angle = torch.deg2rad(torch.tensor([angle]))
    phase_offset = torch.deg2rad(torch.tensor([phase_offset]))
    x = torch.arange(0, nx) * dx
    y = torch.arange(0, ny) * dx
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'y':
        phase = k * torch.sin(angle) * Y + phase_offset
        prism = torch.exp(-1j * phase)
    elif axis == 'x':
        phase = k * torch.sin(angle) * X + phase_offset
        prism = torch.exp(-1j * phase)
    return prism
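
A minimal sketch steering a beam with a prism phase term; the tilt and pitch are illustrative assumptions.

import torch
from odak.learn.wave import prism_grating

wavelength = 515e-9                                      # wavelength in meters
k = 2 * torch.pi / wavelength                            # wave number
prism = prism_grating(nx = 512, ny = 512, k = k, angle = 0.1, dx = 8e-6, axis = 'x')
# Multiplying a complex field with this term tilts its wavefront by
# 0.1 degrees along X, steering the beam off axis.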

quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • k
         See odak.wave.wavenumber for more.
    
  • focal
         Focal length of the quadratic phase function.
    
  • dx
         Pixel pitch.
    
  • offset
         Deviation from the center along X and Y axes.
    

Returns:

  • function ( tensor ) –

    Generated quadratic phase function.

Source code in odak/learn/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):
    """ 
    A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    k          : odak.wave.wavenumber
                 See odak.wave.wavenumber for more.
    focal      : float
                 Focal length of the quadratic phase function.
    dx         : float
                 Pixel pitch.
    offset     : list
                 Deviation from the center along X and Y axes.

    Returns
    -------
    function   : torch.tensor
                 Generated quadratic phase function.
    """
    size = [nx, ny]
    x = torch.linspace(-size[0] * dx / 2, size[0] * dx / 2, size[0]) - offset[1] * dx
    y = torch.linspace(-size[1] * dx / 2, size[1] * dx / 2, size[1]) - offset[0] * dx
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = X**2 + Y**2
    qwf = torch.exp(-0.5j * k / focal * Z)
    return qwf
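
A minimal sketch of a thin lens phase profile; the focal length and pitch are illustrative assumptions.

import torch
from odak.learn.wave import quadratic_phase_function

wavelength = 515e-9                                      # wavelength in meters
k = 2 * torch.pi / wavelength                            # wave number
lens = quadratic_phase_function(nx = 512, ny = 512, k = k, focal = 0.4, dx = 8e-6)
# Multiplying an incoming field with this term focuses it at 0.4 m,
# mimicking a thin lens centered on the grid.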

multiplane_loss

Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.

Source code in odak/learn/wave/loss.py
class multiplane_loss():
    """
    Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.
    """

    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
                 target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], 
                 multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
        """
        Parameters
        ----------
        target_image      : torch.tensor
                            Color target image [3 x m x n].
        target_depth      : torch.tensor
                            Monochrome target depth, same resolution as target_image.
        target_blur_size  : int
                            Maximum target blur size.
        blur_ratio        : float
                            Blur ratio, a value between zero and one.
        number_of_planes  : int
                            Number of planes.
        weights           : list
                            Weights of the loss function.
        multiplier        : float
                            Multiplier to multiply the targets with.
        scheme            : str
                            The type of the loss, `naive` without defocus or `defocus` with defocus.
        reduction         : str
                            Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
        device            : torch.device
                            Device to be used (e.g., cuda, cpu, opencl).
        """
        self.device = device
        self.target_image     = target_image.float().to(self.device)
        self.target_depth     = target_depth.float().to(self.device)
        self.target_blur_size = target_blur_size
        if self.target_blur_size % 2 == 0:
            self.target_blur_size += 1
        self.number_of_planes = number_of_planes
        self.multiplier       = multiplier
        self.weights          = weights
        self.reduction        = reduction
        self.blur_ratio       = blur_ratio
        self.set_targets()
        if scheme == 'defocus':
            self.add_defocus_blur()
        self.loss_function = torch.nn.MSELoss(reduction = self.reduction)

    def get_targets(self):
        """
        Returns
        -------
        targets           : torch.tensor
                            Returns a copy of the targets.
        focus_target      : torch.tensor
                            Returns a copy of the all-in-focus target.
        target_depth      : torch.tensor
                            Returns a copy of the normalized quantized depth map.

        """
        divider = self.number_of_planes - 1
        if divider == 0:
            divider = 1
        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider


    def set_targets(self):
        """
        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
        """
        self.target_depth = self.target_depth * (self.number_of_planes - 1)
        self.target_depth = torch.round(self.target_depth, decimals = 0)
        self.targets      = torch.zeros(
                                        self.number_of_planes,
                                        self.target_image.shape[0],
                                        self.target_image.shape[1],
                                        self.target_image.shape[2],
                                        requires_grad = False,
                                        device = self.device
                                       )
        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
        self.masks        = torch.zeros_like(self.targets)
        for i in range(self.number_of_planes):
            for ch in range(self.target_image.shape[0]):
                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
                new_target = self.target_image[ch] * mask
                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
                self.masks[i, ch] = mask.detach().clone() 


    def add_defocus_blur(self):
        """
        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
        """
        kernel_length = [self.target_blur_size, self.target_blur_size ]
        for ch in range(self.target_image.shape[0]):
            targets_cache = self.targets[:, ch].detach().clone()
            target = torch.sum(targets_cache, axis = 0)
            for i in range(self.number_of_planes):
                defocus = torch.zeros_like(targets_cache[i])
                for j in range(self.number_of_planes):
                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i - j) * self.blur_ratio)]
                    if torch.sum(targets_cache[j]) > 0:
                        if i == j:
                            nsigma = [0., 0.]
                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                        kernel = kernel / torch.sum(kernel)
                        kernel = kernel.unsqueeze(0).unsqueeze(0)
                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
                self.targets[i, ch] = defocus
        self.targets = self.targets.detach().clone() * self.multiplier


    def __call__(self, image, target, plane_id = None):
        """
        Calculates the multiplane loss against a given target.

        Parameters
        ----------
        image         : torch.tensor
                        Image to compare with a target [3 x m x n].
        target        : torch.tensor
                        Target image for comparison [3 x m x n].
        plane_id      : int
                        Number of the plane under test.

        Returns
        -------
        loss          : torch.tensor
                        Computed loss.
        """
        l2 = self.weights[0] * self.loss_function(image, target)
        if isinstance(plane_id, type(None)):
            mask = self.masks
        else:
            mask = self.masks[plane_id, :]
        l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
        l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
        loss = l2 + l2_mask + l2_cor
        return loss
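
A minimal sketch wiring this loss into a per-plane comparison; the image sizes and plane count are illustrative assumptions.

import torch
from odak.learn.wave import multiplane_loss

target_image = torch.rand(3, 256, 256)                   # color target image
target_depth = torch.rand(256, 256)                      # normalized depth in [0, 1]
loss_function = multiplane_loss(
                                target_image,
                                target_depth,
                                number_of_planes = 4,
                                scheme = 'defocus'
                               )
targets, focus_target, depth = loss_function.get_targets()
reconstructions = torch.rand(4, 3, 256, 256)             # stand-in reconstructions
total_loss = 0.
for plane_id in range(4):
    total_loss = total_loss + loss_function(
                                            reconstructions[plane_id],
                                            targets[plane_id],
                                            plane_id = plane_id
                                           )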

__call__(image, target, plane_id=None)

Calculates the multiplane loss against a given target.

Parameters:

  • image
            Image to compare with a target [3 x m x n].
    
  • target
            Target image for comparison [3 x m x n].
    
  • plane_id
            Number of the plane under test.
    

Returns:

  • loss ( tensor ) –

    Computed loss.

Source code in odak/learn/wave/loss.py
def __call__(self, image, target, plane_id = None):
    """
    Calculates the multiplane loss against a given target.

    Parameters
    ----------
    image         : torch.tensor
                    Image to compare with a target [3 x m x n].
    target        : torch.tensor
                    Target image for comparison [3 x m x n].
    plane_id      : int
                    Number of the plane under test.

    Returns
    -------
    loss          : torch.tensor
                    Computed loss.
    """
    l2 = self.weights[0] * self.loss_function(image, target)
    if isinstance(plane_id, type(None)):
        mask = self.masks
    else:
        mask = self.masks[plane_id, :]
    l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
    l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
    loss = l2 + l2_mask + l2_cor
    return loss

__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, weights=[1.0, 2.1, 0.6], multiplier=1.0, scheme='defocus', reduction='mean', device=torch.device('cpu'))

Parameters:

  • target_image
                Color target image [3 x m x n].
    
  • target_depth
                Monochrome target depth, same resolution as target_image.
    
  • target_blur_size
                Maximum target blur size.
    
  • blur_ratio
                Blur ratio, a value between zero and one.
    
  • number_of_planes
                Number of planes.
    
  • weights
                Weights of the loss function.
    
  • multiplier
                 Multiplier to multiply the targets with.
    
  • scheme
                The type of the loss, `naive` without defocus or `defocus` with defocus.
    
  • reduction
                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    
  • device
                Device to be used (e.g., cuda, cpu, opencl).
    
Source code in odak/learn/wave/loss.py
def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
             target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], 
             multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
    """
    Parameters
    ----------
    target_image      : torch.tensor
                        Color target image [3 x m x n].
    target_depth      : torch.tensor
                        Monochrome target depth, same resolution as target_image.
    target_blur_size  : int
                        Maximum target blur size.
    blur_ratio        : float
                        Blur ratio, a value between zero and one.
    number_of_planes  : int
                        Number of planes.
    weights           : list
                        Weights of the loss function.
    multiplier        : float
                        Multiplier to multiply the targets with.
    scheme            : str
                        The type of the loss, `naive` without defocus or `defocus` with defocus.
    reduction         : str
                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    device            : torch.device
                        Device to be used (e.g., cuda, cpu, opencl).
    """
    self.device = device
    self.target_image     = target_image.float().to(self.device)
    self.target_depth     = target_depth.float().to(self.device)
    self.target_blur_size = target_blur_size
    if self.target_blur_size % 2 == 0:
        self.target_blur_size += 1
    self.number_of_planes = number_of_planes
    self.multiplier       = multiplier
    self.weights          = weights
    self.reduction        = reduction
    self.blur_ratio       = blur_ratio
    self.set_targets()
    if scheme == 'defocus':
        self.add_defocus_blur()
    self.loss_function = torch.nn.MSELoss(reduction = self.reduction)

add_defocus_blur()

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def add_defocus_blur(self):
    """
    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
    """
    kernel_length = [self.target_blur_size, self.target_blur_size ]
    for ch in range(self.target_image.shape[0]):
        targets_cache = self.targets[:, ch].detach().clone()
        target = torch.sum(targets_cache, axis = 0)
        for i in range(self.number_of_planes):
            defocus = torch.zeros_like(targets_cache[i])
            for j in range(self.number_of_planes):
                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i - j) * self.blur_ratio)]
                if torch.sum(targets_cache[j]) > 0:
                    if i == j:
                        nsigma = [0., 0.]
                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                    kernel = kernel / torch.sum(kernel)
                    kernel = kernel.unsqueeze(0).unsqueeze(0)
                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
            self.targets[i, ch] = defocus
    self.targets = self.targets.detach().clone() * self.multiplier

get_targets()

Returns:

  • targets ( tensor ) –

    Returns a copy of the targets.

  • focus_target ( tensor ) –

    Returns a copy of the all-in-focus target.

  • target_depth ( tensor ) –

    Returns a copy of the normalized quantized depth map.

Source code in odak/learn/wave/loss.py
def get_targets(self):
    """
    Returns
    -------
    targets           : torch.tensor
                        Returns a copy of the targets.
    focus_target      : torch.tensor
                        Returns a copy of the all-in-focus target.
    target_depth      : torch.tensor
                        Returns a copy of the normalized quantized depth map.

    """
    divider = self.number_of_planes - 1
    if divider == 0:
        divider = 1
    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider

set_targets()

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):
    """
    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
    """
    self.target_depth = self.target_depth * (self.number_of_planes - 1)
    self.target_depth = torch.round(self.target_depth, decimals = 0)
    self.targets      = torch.zeros(
                                    self.number_of_planes,
                                    self.target_image.shape[0],
                                    self.target_image.shape[1],
                                    self.target_image.shape[2],
                                    requires_grad = False,
                                    device = self.device
                                   )
    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
    self.masks        = torch.zeros_like(self.targets)
    for i in range(self.number_of_planes):
        for ch in range(self.target_image.shape[0]):
            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
            new_target = self.target_image[ch] * mask
            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
            self.masks[i, ch] = mask.detach().clone() 

perceptual_multiplane_loss

Perceptual loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.

Source code in odak/learn/wave/loss.py
class perceptual_multiplane_loss():
    """
    Perceptual loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.
    """

    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
                 target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', 
                 base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},
                 additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):
        """
        Parameters
        ----------
        target_image            : torch.tensor
                                    Color target image [3 x m x n].
        target_depth            : torch.tensor
                                    Monochrome target depth, same resolution as target_image.
        target_blur_size        : int
                                    Maximum target blur size.
        blur_ratio              : float
                                    Blur ratio, a value between zero and one.
        number_of_planes        : int
                                    Number of planes.
        multiplier              : float
                                    Multiplier applied to the targets.
        scheme                  : str
                                    The type of the loss, `naive` without defocus or `defocus` with defocus.
        base_loss_weights       : dict
                                    Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
        additional_loss_weights : dict
                                    Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
        reduction               : str
                                    Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
        return_components       : bool
                                    If True (False by default), returns the components of the loss as a dict.
        device                  : torch.device
                                    Device to be used (e.g., cuda, cpu, opencl).
        """
        self.device = device
        self.target_image     = target_image.float().to(self.device)
        self.target_depth     = target_depth.float().to(self.device)
        self.target_blur_size = target_blur_size
        if self.target_blur_size % 2 == 0:
            self.target_blur_size += 1
        self.number_of_planes = number_of_planes
        self.multiplier       = multiplier
        self.reduction        = reduction
        if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:
            logging.warning("Reduction cannot be 'none' for additional loss functions. Changing reduction to 'mean'.")
            self.reduction = 'mean'
        self.blur_ratio       = blur_ratio
        self.set_targets()
        if scheme == 'defocus':
            self.add_defocus_blur()
        self.base_loss_weights = base_loss_weights
        self.additional_loss_weights = additional_loss_weights
        self.return_components = return_components
        self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)
        self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)
        for key in self.additional_loss_weights.keys():
            if self.additional_loss_weights[key]:
                if key == 'cvvdp':
                    self.cvvdp = CVVDP(device = device)
                if key == 'fvvdp':
                    self.fvvdp = FVVDP()
                if key == 'lpips':
                    self.lpips = LPIPS()
                if key == 'psnr':
                    self.psnr = PSNR()
                if key == 'ssim':
                    self.ssim = SSIM()
                if key == 'msssim':
                    self.msssim = MSSSIM()

    def get_targets(self):
        """
        Returns
        -------
        targets           : torch.tensor
                            Returns a copy of the targets.
        focus_target      : torch.tensor
                            Returns a copy of the all-in-focus target image.
        target_depth      : torch.tensor
                            Returns a copy of the normalized quantized depth map.

        """
        divider = self.number_of_planes - 1
        if divider == 0:
            divider = 1
        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider


    def set_targets(self):
        """
        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
        """
        self.target_depth = self.target_depth * (self.number_of_planes - 1)
        self.target_depth = torch.round(self.target_depth, decimals = 0)
        self.targets      = torch.zeros(
                                        self.number_of_planes,
                                        self.target_image.shape[0],
                                        self.target_image.shape[1],
                                        self.target_image.shape[2],
                                        requires_grad = False,
                                        device = self.device
                                       )
        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
        self.masks        = torch.zeros_like(self.targets)
        for i in range(self.number_of_planes):
            for ch in range(self.target_image.shape[0]):
                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
                new_target = self.target_image[ch] * mask
                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
                self.masks[i, ch] = mask.detach().clone() 


    def add_defocus_blur(self):
        """
        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
        """
        kernel_length = [self.target_blur_size, self.target_blur_size]
        for ch in range(self.target_image.shape[0]):
            targets_cache = self.targets[:, ch].detach().clone()
            target = torch.sum(targets_cache, axis = 0)
            for i in range(self.number_of_planes):
                defocus = torch.zeros_like(targets_cache[i])
                for j in range(self.number_of_planes):
                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i - j) * self.blur_ratio)]
                    if torch.sum(targets_cache[j]) > 0:
                        if i == j:
                            nsigma = [0., 0.]
                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                        kernel = kernel / torch.sum(kernel)
                        kernel = kernel.unsqueeze(0).unsqueeze(0)
                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
                self.targets[i, ch] = defocus
        self.targets = self.targets.detach().clone() * self.multiplier


    def __call__(self, image, target, plane_id = None):
        """
        Calculates the multiplane loss against a given target.

        Parameters
        ----------
        image         : torch.tensor
                        Image to compare with a target [3 x m x n].
        target        : torch.tensor
                        Target image for comparison [3 x m x n].
        plane_id      : int
                        Index of the plane under test.

        Returns
        -------
        loss          : torch.tensor
                        Computed loss.
        """
        loss_components = {}
        if plane_id is None:
            mask = self.masks
        else:
            mask = self.masks[plane_id, :]
        l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)
        l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)
        l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)
        loss_components['l2'] = l2
        loss_components['l2_mask'] = l2_mask
        loss_components['l2_cor'] = l2_cor
        loss = l2 + l2_mask + l2_cor

        l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)
        l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)
        l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)
        loss_components['l1'] = l1
        loss_components['l1_mask'] = l1_mask
        loss_components['l1_cor'] = l1_cor
        loss += l1 + l1_mask + l1_cor

        for key in self.additional_loss_weights.keys():
            if self.additional_loss_weights[key]:
                if key == 'cvvdp':
                    loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
                    loss_components['cvvdp'] = loss_cvvdp
                    loss += loss_cvvdp
                if key == 'fvvdp':
                    loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
                    loss_components['fvvdp'] = loss_fvvdp
                    loss += loss_fvvdp
                if key == 'lpips':
                    loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
                    loss_components['lpips'] = loss_lpips
                    loss += loss_lpips
                if key == 'psnr':
                    loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
                    loss_components['psnr'] = loss_psnr
                    loss += loss_psnr
                if key == 'ssim':
                    loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
                    loss_components['ssim'] = loss_ssim
                    loss += loss_ssim
                if key == 'msssim':
                    loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
                    loss_components['msssim'] = loss_msssim
                    loss += loss_msssim
        if self.return_components:
            return loss, loss_components
        return loss
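
A minimal usage sketch, assuming perceptual_multiplane_loss is importable from odak.learn.wave and using illustrative shapes; additional_loss_weights is left empty here so the default 'cvvdp' entry does not pull in the ColorVideoVDP dependency:

import torch
from odak.learn.wave import perceptual_multiplane_loss

target_image = torch.rand(3, 256, 256)  # hypothetical RGB target
target_depth = torch.rand(256, 256)     # hypothetical normalized depth map
loss_function = perceptual_multiplane_loss(
                                           target_image,
                                           target_depth,
                                           number_of_planes = 4,
                                           scheme = 'defocus',
                                           additional_loss_weights = {},
                                           device = torch.device('cpu')
                                          )
targets, focus_target, depth = loss_function.get_targets()
estimate = torch.rand_like(targets)     # stand-in for a reconstructed focal stack
loss = loss_function(estimate, targets)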

__call__(image, target, plane_id=None)

Calculates the multiplane loss against a given target.

Parameters:

  • image
            Image to compare with a target [3 x m x n].
    
  • target
            Target image for comparison [3 x m x n].
    
  • plane_id
            Index of the plane under test.
    

Returns:

  • loss ( tensor ) –

    Computed loss.

Source code in odak/learn/wave/loss.py
def __call__(self, image, target, plane_id = None):
    """
    Calculates the multiplane loss against a given target.

    Parameters
    ----------
    image         : torch.tensor
                    Image to compare with a target [3 x m x n].
    target        : torch.tensor
                    Target image for comparison [3 x m x n].
    plane_id      : int
                    Index of the plane under test.

    Returns
    -------
    loss          : torch.tensor
                    Computed loss.
    """
    loss_components = {}
    if plane_id is None:
        mask = self.masks
    else:
        mask = self.masks[plane_id, :]
    l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)
    l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)
    l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)
    loss_components['l2'] = l2
    loss_components['l2_mask'] = l2_mask
    loss_components['l2_cor'] = l2_cor
    loss = l2 + l2_mask + l2_cor

    l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)
    l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)
    l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)
    loss_components['l1'] = l1
    loss_components['l1_mask'] = l1_mask
    loss_components['l1_cor'] = l1_cor
    loss += l1 + l1_mask + l1_cor

    for key in self.additional_loss_weights.keys():
        if self.additional_loss_weights[key]:
            if key == 'cvvdp':
                loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
                loss_components['cvvdp'] = loss_cvvdp
                loss += loss_cvvdp
            if key == 'fvvdp':
                loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
                loss_components['fvvdp'] = loss_fvvdp
                loss += loss_fvvdp
            if key == 'lpips':
                loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
                loss_components['lpips'] = loss_lpips
                loss += loss_lpips
            if key == 'psnr':
                loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
                loss_components['psnr'] = loss_psnr
                loss += loss_psnr
            if key == 'ssim':
                loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
                loss_components['ssim'] = loss_ssim
                loss += loss_ssim
            if key == 'msssim':
                loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
                loss_components['msssim'] = loss_msssim
                loss += loss_msssim
    if self.return_components:
        return loss, loss_components
    return loss
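
When optimizing plane by plane, plane_id restricts the masked terms to a single plane's mask. A short continuation of the hypothetical sketch above:

plane_id = 2
estimate_plane = torch.rand_like(targets[plane_id])  # [3 x m x n] stand-in
loss_plane = loss_function(estimate_plane, targets[plane_id], plane_id = plane_id)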

__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, multiplier=1.0, scheme='defocus', base_loss_weights={'base_l2_loss': 1.0, 'loss_l2_mask': 1.0, 'loss_l2_cor': 1.0, 'base_l1_loss': 1.0, 'loss_l1_mask': 1.0, 'loss_l1_cor': 1.0}, additional_loss_weights={'cvvdp': 1.0}, reduction='mean', return_components=False, device=torch.device('cpu'))

Parameters:

  • target_image
                        Color target image [3 x m x n].
    
  • target_depth
                        Monochrome target depth, same resolution as target_image.
    
  • target_blur_size
                        Maximum target blur size.
    
  • blur_ratio
                        Blur ratio, a value between zero and one.
    
  • number_of_planes
                        Number of planes.
    
  • multiplier
                        Multiplier applied to the targets.
    
  • scheme
                        The type of the loss, `naive` without defocus or `defocus` with defocus.
    
  • base_loss_weights
                        Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
    
  • additional_loss_weights (dict, default: {'cvvdp': 1.0} ) –
                        Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
    
  • reduction
                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    
  • return_components
                        If True (False by default), returns the components of the loss as a dict.
    
  • device
                        Device to be used (e.g., cuda, cpu, opencl).
    
Source code in odak/learn/wave/loss.py
def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
             target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', 
             base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},
             additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):
    """
    Parameters
    ----------
    target_image            : torch.tensor
                                Color target image [3 x m x n].
    target_depth            : torch.tensor
                                Monochrome target depth, same resolution as target_image.
    target_blur_size        : int
                                Maximum target blur size.
    blur_ratio              : float
                                Blur ratio, a value between zero and one.
    number_of_planes        : int
                                Number of planes.
    multiplier              : float
                                Multiplier applied to the targets.
    scheme                  : str
                                The type of the loss, `naive` without defocus or `defocus` with defocus.
    base_loss_weights       : dict
                                Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
    additional_loss_weights : dict
                                Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
    reduction               : str
                                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    return_components       : bool
                                If True (False by default), returns the components of the loss as a dict.
    device                  : torch.device
                                Device to be used (e.g., cuda, cpu, opencl).
    """
    self.device = device
    self.target_image     = target_image.float().to(self.device)
    self.target_depth     = target_depth.float().to(self.device)
    self.target_blur_size = target_blur_size
    if self.target_blur_size % 2 == 0:
        self.target_blur_size += 1
    self.number_of_planes = number_of_planes
    self.multiplier       = multiplier
    self.reduction        = reduction
    if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:
        logging.warning("Reduction cannot be 'none' for additional loss functions. Changing reduction to 'mean'.")
        self.reduction = 'mean'
    self.blur_ratio       = blur_ratio
    self.set_targets()
    if scheme == 'defocus':
        self.add_defocus_blur()
    self.base_loss_weights = base_loss_weights
    self.additional_loss_weights = additional_loss_weights
    self.return_components = return_components
    self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)
    self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)
    for key in self.additional_loss_weights.keys():
        if self.additional_loss_weights[key]:
            if key == 'cvvdp':
                self.cvvdp = CVVDP(device = device)
            if key == 'fvvdp':
                self.fvvdp = FVVDP()
            if key == 'lpips':
                self.lpips = LPIPS()
            if key == 'psnr':
                self.psnr = PSNR()
            if key == 'ssim':
                self.ssim = SSIM()
            if key == 'msssim':
                self.msssim = MSSSIM()
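
The two weight dictionaries make individual terms easy to rebalance or silence, and zero-valued entries in additional_loss_weights also skip constructing the corresponding metric. A sketch with illustrative weights, reusing the hypothetical tensors from the sketch above:

base_weights = {
                'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1.,
                'base_l1_loss': 0., 'loss_l1_mask': 0., 'loss_l1_cor': 0.
               }
loss_function = perceptual_multiplane_loss(
                                           target_image,
                                           target_depth,
                                           base_loss_weights = base_weights,
                                           additional_loss_weights = {},
                                           return_components = True
                                          )
loss, components = loss_function(estimate, targets)  # components maps term names to their weighted values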

add_defocus_blur()

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def add_defocus_blur(self):
    """
    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
    """
    kernel_length = [self.target_blur_size, self.target_blur_size]
    for ch in range(self.target_image.shape[0]):
        targets_cache = self.targets[:, ch].detach().clone()
        target = torch.sum(targets_cache, axis = 0)
        for i in range(self.number_of_planes):
            defocus = torch.zeros_like(targets_cache[i])
            for j in range(self.number_of_planes):
                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i - j) * self.blur_ratio)]
                if torch.sum(targets_cache[j]) > 0:
                    if i == j:
                        nsigma = [0., 0.]
                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                    kernel = kernel / torch.sum(kernel)
                    kernel = kernel.unsqueeze(0).unsqueeze(0)
                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
            self.targets[i, ch] = defocus
    self.targets = self.targets.detach().clone() * self.multiplier
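
The blur strength grows with plane separation: content from plane j is blurred with a Gaussian of sigma int(abs(i - j) * blur_ratio) when composited into plane i, and nsigma is reset to zero when i == j. A quick numeric illustration of that schedule with an illustrative blur_ratio of 1.0 (the class default is 0.25):

blur_ratio = 1.0
number_of_planes = 4
for i in range(number_of_planes):
    sigmas = [0 if i == j else int(abs(i - j) * blur_ratio)
              for j in range(number_of_planes)]
    print('plane', i, '-> sigma per source plane:', sigmas)
# plane 0 -> sigma per source plane: [0, 1, 2, 3]
# plane 1 -> sigma per source plane: [1, 0, 1, 2]
# plane 2 -> sigma per source plane: [2, 1, 0, 1]
# plane 3 -> sigma per source plane: [3, 2, 1, 0]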

get_targets()

Returns:

  • targets ( tensor ) –

    Returns a copy of the targets.

  • focus_target ( tensor ) –

    Returns a copy of the all-in-focus target image.

  • target_depth ( tensor ) –

    Returns a copy of the normalized quantized depth map.

Source code in odak/learn/wave/loss.py
def get_targets(self):
    """
    Returns
    -------
    targets           : torch.tensor
                        Returns a copy of the targets.
    focus_target      : torch.tensor
                        Returns a copy of the all-in-focus target image.
    target_depth      : torch.tensor
                        Returns a copy of the normalized quantized depth map.

    """
    divider = self.number_of_planes - 1
    if divider == 0:
        divider = 1
    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider
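
Note that get_targets() returns three tensors: the per-plane targets, the all-in-focus image assembled in set_targets(), and the quantized depth map rescaled back to [0, 1]. Continuing the hypothetical sketch above:

targets, focus_target, normalized_depth = loss_function.get_targets()
# targets          : [number_of_planes x 3 x m x n] per-plane targets
# focus_target     : [3 x m x n] all-in-focus composite
# normalized_depth : quantized depth divided back into [0, 1]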

set_targets()

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):
    """
    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
    """
    self.target_depth = self.target_depth * (self.number_of_planes - 1)
    self.target_depth = torch.round(self.target_depth, decimals = 0)
    self.targets      = torch.zeros(
                                    self.number_of_planes,
                                    self.target_image.shape[0],
                                    self.target_image.shape[1],
                                    self.target_image.shape[2],
                                    requires_grad = False,
                                    device = self.device
                                   )
    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
    self.masks        = torch.zeros_like(self.targets)
    for i in range(self.number_of_planes):
        for ch in range(self.target_image.shape[0]):
            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
            new_target = self.target_image[ch] * mask
            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
            self.masks[i, ch] = mask.detach().clone() 

phase_gradient

Bases: Module

The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

This implements a convolution of the phase with a kernel.

The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.

Source code in odak/learn/wave/loss.py
class phase_gradient(nn.Module):

    """
    The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

    This implements a convolution of the phase with a kernel.

    The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.
    """


    def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
        """
        Parameters
        ----------
        kernel                  : torch.tensor
                                    Convolution filter kernel, 3 by 3 Laplacian kernel by default.
        loss                    : torch.nn.Module
                                    Loss function, L2 loss by default.
        """
        super(phase_gradient, self).__init__()
        self.device = device
        self.loss = loss
        if kernel is None:
            self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
        else:
            if len(kernel.shape) == 4:
                self.kernel = kernel
            else:
                self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
        self.kernel = Variable(self.kernel.to(self.device))


    def forward(self, phase):
        """
        Calculates the phase gradient loss.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(phase.shape) == 2:
            phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
        edge_detect = self.functional_conv2d(phase)
        loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
        return loss_value


    def functional_conv2d(self, phase):
        """
        Calculates the gradient of the phase.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        edge_detect              : torch.tensor
                                    The computed phase gradient.
        """
        edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
        return edge_detect
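
A minimal sketch of using phase_gradient as a smoothness regularizer on an optimized phase pattern, assuming the class is importable from odak.learn.wave; the weight and loop length are illustrative:

import torch
from odak.learn.wave import phase_gradient

phase = torch.rand(256, 256, requires_grad = True)  # hypothetical phase pattern
regularizer = phase_gradient()
optimizer = torch.optim.Adam([phase], lr = 1e-2)
for step in range(100):
    optimizer.zero_grad()
    # In practice this term would be added to a primary objective.
    loss = 1e-1 * regularizer(phase)
    loss.backward()
    optimizer.step()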

__init__(kernel=None, loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel
                        Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    
  • loss
                        Loss function, L2 loss by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
    """
    Parameters
    ----------
    kernel                  : torch.tensor
                                Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    loss                    : torch.nn.Module
                                Loss function, L2 loss by default.
    """
    super(phase_gradient, self).__init__()
    self.device = device
    self.loss = loss
    if kernel is None:
        self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
    else:
        if len(kernel.shape) == 4:
            self.kernel = kernel
        else:
            self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
    self.kernel = Variable(self.kernel.to(self.device))
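
Any edge-detection kernel of shape [m x n] or [1 x 1 x m x n] can be supplied in place of the default Laplacian. A sketch with a horizontal Sobel kernel (the kernel choice is illustrative):

import torch
from odak.learn.wave import phase_gradient

sobel_x = torch.tensor([[-1., 0., 1.],
                        [-2., 0., 2.],
                        [-1., 0., 1.]]) / 8.
regularizer = phase_gradient(kernel = sobel_x)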

forward(phase)

Calculates the phase gradient loss.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, phase):
    """
    Calculates the phase gradient loss.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(phase.shape) == 2:
        phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
    edge_detect = self.functional_conv2d(phase)
    loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
    return loss_value

functional_conv2d(phase)

Calculates the gradient of the phase.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • edge_detect ( tensor ) –

    The computed phase gradient.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, phase):
    """
    Calculates the gradient of the phase.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    edge_detect              : torch.tensor
                                The computed phase gradient.
    """
    edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
    return edge_detect

speckle_contrast

Bases: Module

The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast, and mean and sigma are the mean and standard deviation of the intensity.

We refer to the following paper:

Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. https://doi.org/10.1038/s41598-020-75947-0

Source code in odak/learn/wave/loss.py
class speckle_contrast(nn.Module):

    """
    The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast, and mean and sigma are the mean and standard deviation of the intensity.

    We refer to the following paper:

    Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. https://doi.org/10.1038/s41598-020-75947-0
    """


    def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
        """
        Parameters
        ----------
        kernel_size             : int
                                    Convolution filter kernel size, 11 by 11 average kernel by default.
        step_size               : tuple
                                    Convolution stride in height and width directions.
        loss                    : torch.nn.Module
                                    Loss function, L2 loss by default.
        """
        super(speckle_contrast, self).__init__()
        self.device = device
        self.loss = loss
        self.step_size = step_size
        self.kernel_size = kernel_size
        self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
        self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))


    def forward(self, intensity):
        """
        Calculates the speckle contrast loss.

        Parameters
        ----------
        intensity               : torch.tensor
                                    Intensity of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(intensity.shape) == 2:
            intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
        Speckle_C = self.functional_conv2d(intensity)
        loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
        return loss_value


    def functional_conv2d(self, intensity):
        """
        Calculates the speckle contrast of the intensity.

        Parameters
        ----------
        intensity                : torch.tensor
                                    Intensity of the complex field.

        Returns
        -------

        Speckle_C               : torch.tensor
                                    The computed speckle contrast.
        """
        # Local mean via an averaging convolution.
        mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
        # Local standard deviation via sqrt(E[I^2] - E[I]^2); note that 'var' holds the standard deviation, not the variance.
        var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
        Speckle_C = var / mean
        return Speckle_C
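
A minimal sketch of penalizing speckle in a reconstructed intensity image, assuming the class is importable from odak.learn.wave; the shapes and stride are illustrative:

import torch
from odak.learn.wave import speckle_contrast

intensity = torch.rand(512, 512) + 0.5  # hypothetical positive intensity image
loss_function = speckle_contrast(kernel_size = 11, step_size = (11, 11))
loss = loss_function(intensity)  # drives local sigma / mean toward zero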

__init__(kernel_size=11, step_size=(1, 1), loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel_size
                        Convolution filter kernel size, 11 by 11 average kernel by default.
    
  • step_size
                        Convolution stride in height and width direction.
    
  • loss
                        Loss function, L2 loss by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
    """
    Parameters
    ----------
    kernel_size             : int
                                Convolution filter kernel size, 11 by 11 average kernel by default.
    step_size               : tuple
                                Convolution stride in height and width directions.
    loss                    : torch.nn.Module
                                Loss function, L2 loss by default.
    """
    super(speckle_contrast, self).__init__()
    self.device = device
    self.loss = loss
    self.step_size = step_size
    self.kernel_size = kernel_size
    self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
    self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))

forward(intensity)

Calculates the speckle contrast loss.

Parameters:

  • intensity
                        Intensity of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, intensity):
    """
    Calculates the speckle contrast loss.

    Parameters
    ----------
    intensity               : torch.tensor
                                Intensity of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(intensity.shape) == 2:
        intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
    Speckle_C = self.functional_conv2d(intensity)
    loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
    return loss_value

functional_conv2d(intensity)

Calculates the speckle contrast of the intensity.

Parameters:

  • intensity
                        Intensity of the complex field.
    

Returns:

  • Speckle_C ( tensor ) –

    The computed speckle contrast.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, intensity):
    """
    Calculates the speckle contrast of the intensity.

    Parameters
    ----------
    intensity                : torch.tensor
                                Intensity of the complex field.

    Returns
    -------

    Speckle_C               : torch.tensor