odak.learn.wave

angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution-based beam propagation with the Angular Spectrum method.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
    """
    A definition to calculate convolution-based beam propagation with the Angular Spectrum method.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_angular_spectrum_kernel(
                                    field.shape[-2], 
                                    field.shape[-1], 
                                    dx = dx, 
                                    wavelength = wavelength, 
                                    distance = distance, 
                                    device = field.device
                                   )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
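
Example: a minimal usage sketch of angular_spectrum, assuming torch and odak are installed and the function is exposed under odak.learn.wave as documented on this page; the field contents, pixel pitch and distance below are arbitrary placeholder values.

import torch
import numpy as np
import odak

wavelength = 515e-9                                     # wavelength in meters
dx = 8e-6                                               # pixel pitch in meters
distance = 5e-3                                         # propagation distance in meters
k = 2 * np.pi / wavelength                              # wave number
field = torch.zeros(512, 512, dtype = torch.complex64)
field[236:276, 236:276] = 1.                            # a simple square aperture as the input field
result = odak.learn.wave.angular_spectrum(field, k, distance, dx, wavelength)
intensity = result.abs() ** 2                           # intensity at the target plane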

band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate bandlimited angular spectrum based beam propagation. For more information, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics Express 17.22 (2009): 19662-19673.

Parameters:

  • field
               A complex field.
               The expected size is [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
    """
    A definition to calculate bandlimited angular spectrum based beam propagation. For more information, see
    `Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics Express 17.22 (2009): 19662-19673`.

    Parameters
    ----------
    field            : torch.complex
                       A complex field.
                       The expected size is [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    H = get_band_limited_angular_spectrum_kernel(
                                                 field.shape[-2], 
                                                 field.shape[-1], 
                                                 dx = dx, 
                                                 wavelength = wavelength, 
                                                 distance = distance, 
                                                 device = field.device
                                                )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result

custom(field, kernel, zero_padding=False, aperture=1.0)

A definition to calculate convolution-based beam propagation with a custom, user-provided complex kernel.

Parameters:

  • field
               Complex field [m x n].
    
  • kernel
               Custom complex kernel for beam propagation.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def custom(field, kernel, zero_padding = False, aperture = 1.):
    """
    A definition to calculate convolution-based beam propagation with a custom, user-provided complex kernel.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    kernel           : torch.complex
                       Custom complex kernel for beam propagation.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    if type(kernel) == type(None):
        H = torch.zeros(field.shape).to(field.device)
    else:
        H = kernel * aperture
    U1 = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(field))) * aperture
    if zero_padding == False:
        U2 = H * U1
    elif zero_padding == True:
        U2 = zero_pad(H * U1)
    result = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(U2)))
    return result
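
Example: a minimal sketch that precomputes a propagation kernel once with get_propagation_kernel and reuses it with custom; all optical values below are arbitrary placeholders.

import torch
import odak

field = torch.ones(512, 512, dtype = torch.complex64)   # an arbitrary unit-amplitude field
H = odak.learn.wave.get_propagation_kernel(
                                            nu = field.shape[-2],
                                            nv = field.shape[-1],
                                            dx = 8e-6,
                                            wavelength = 515e-9,
                                            distance = 1e-2,
                                            propagation_type = 'Bandlimited Angular Spectrum'
                                           )
result = odak.learn.wave.custom(field, H)                # the same kernel can be reused for many fields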

fraunhofer(field, k, distance, dx, wavelength)

A definition to calculate light transport using the Fraunhofer approximation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def fraunhofer(field, k, distance, dx, wavelength):
    """
    A definition to calculate light transport using the Fraunhofer approximation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).
    """
    nv, nu = field.shape[-1], field.shape[-2]
    x = torch.linspace(-nv*dx/2, nv*dx/2, nv, dtype=torch.float32)
    y = torch.linspace(-nu*dx/2, nu*dx/2, nu, dtype=torch.float32)
    Y, X = torch.meshgrid(y, x, indexing='ij')
    Z = torch.pow(X, 2) + torch.pow(Y, 2)
    c = 1./(1j*wavelength*distance)*torch.exp(1j*k*0.5/distance*Z)
    c = c.to(field.device)
    result = c * \
             torch.fft.ifftshift(torch.fft.fft2(
             torch.fft.fftshift(field)))*pow(dx, 2)
    return result

gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel')

Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

Parameters:

  • field
               Complex field (MxN).
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • slm_range
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Type of the propagation (see odak.learn.wave.propagate_beam).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

  • reconstruction ( cfloat ) –

    Calculated reconstruction using calculated hologram.

Source code in odak/learn/wave/classical.py
def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel'):
    """
    Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

    Parameters
    ----------
    field            : torch.cfloat
                       Complex field (MxN).
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    slm_range        : float
                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    propagation_type : str
                       Type of the propagation (see odak.learn.wave.propagate_beam).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    reconstruction   : torch.cfloat
                       Calculated reconstruction using calculated hologram. 
    """
    k = wavenumber(wavelength)
    reconstruction = field
    for i in range(n_iterations):
        hologram = propagate_beam(
            reconstruction, k, -distance, dx, wavelength, propagation_type)
        hologram, _ = produce_phase_only_slm_pattern(hologram, slm_range)
        reconstruction = propagate_beam(
            hologram, k, distance, dx, wavelength, propagation_type)
        reconstruction = set_amplitude(reconstruction, field)
    reconstruction = propagate_beam(
        hologram, k, distance, dx, wavelength, propagation_type)
    return hologram, reconstruction
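
Example: a minimal sketch for retrieving a phase-only hologram from a target amplitude with gerchberg_saxton; the target, distance, pixel pitch and iteration count are arbitrary placeholders.

import torch
import odak

target_amplitude = torch.rand(512, 512)                 # target amplitude in [0, 1]
target_field = target_amplitude.to(torch.complex64)     # complex field with zero phase
hologram, reconstruction = odak.learn.wave.gerchberg_saxton(
                                                             target_field,
                                                             n_iterations = 10,
                                                             distance = 0.15,
                                                             dx = 8e-6,
                                                             wavelength = 515e-9,
                                                             propagation_type = 'Bandlimited Angular Spectrum'
                                                            )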

get_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( float ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
    """
    Helper function for odak.learn.wave.angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : float
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. /2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. /2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    H = torch.exp(1j  * distance * (2 * (np.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
    H = H.to(device)
    return H

get_band_limited_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.band_limited_angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( float ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
    """
    Helper function for odak.learn.wave.band_limited_angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : float
                         Complex kernel in Fourier domain.
    """
    x = dx * float(nu)
    y = dx * float(nv)
    fx = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * x),
                         1 / (2 * dx) - 0.5 / (2 * x),
                         nu,
                         dtype = torch.float32,
                         device = device
                        )
    fy = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * y),
                        1 / (2 * dx) - 0.5 / (2 * y),
                        nv,
                        dtype = torch.float32,
                        device = device
                       )
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    HH_exp = 2 * np.pi * torch.sqrt(1 / wavelength ** 2 - (FX ** 2 + FY ** 2))
    distance = torch.tensor([distance], device = device)
    H_exp = torch.mul(HH_exp, distance)
    fx_max = 1 / torch.sqrt((2 * distance * (1 / x))**2 + 1) / wavelength
    fy_max = 1 / torch.sqrt((2 * distance * (1 / y))**2 + 1) / wavelength
    H_filter = ((torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).clone().detach()
    H = generate_complex_field(H_filter, H_exp)
    return H

get_propagation_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), propagation_type='Bandlimited Angular Spectrum', scale=1)

Get propagation kernel for the propagation type.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • propagation_type
                 Propagation type.
                 The options are `Angular Spectrum`, `Bandlimited Angular Spectrum` and `Transfer Function Fresnel`.
    
  • scale
                 Scale factor for scaled beam propagation.
    

Returns:

  • kernel ( tensor ) –

    Complex kernel for the given propagation type.

Source code in odak/learn/wave/classical.py
def get_propagation_kernel(
                           nu, 
                           nv, 
                           dx = 8e-6, 
                           wavelength = 515e-9, 
                           distance = 0., 
                           device = torch.device('cpu'), 
                           propagation_type = 'Bandlimited Angular Spectrum', 
                           scale = 1
                          ):
    """
    Get propagation kernel for the propagation type.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    propagation_type   : str
                         Propagation type.
                         The options are `Angular Spectrum`, `Bandlimited Angular Spectrum` and `Transfer Function Fresnel`.
    scale              : int
                         Scale factor for scaled beam propagation.


    Returns
    -------
    kernel             : torch.tensor
                         Complex kernel for the given propagation type.
    """
    if propagation_type == 'Bandlimited Angular Spectrum':
        kernel = get_band_limited_angular_spectrum_kernel(nu, nv, dx, wavelength, distance, device)
    elif propagation_type == 'Angular Spectrum':
        kernel = get_angular_spectrum_kernel(nu, nv, dx, wavelength, distance, device)
    elif propagation_type == 'Transfer Function Fresnel':
        kernel = get_transfer_function_fresnel_kernel(nu, nv, dx, wavelength, distance, device)
    else:
        logging.warning('Propagation type not recognized')
        assert True == False
    return kernel

get_transfer_function_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.transfer_function_fresnel.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( float ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
    """
    Helper function for odak.learn.wave.transfer_function_fresnel.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : float
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing = 'ij')
    k = wavenumber(wavelength)
    H = torch.exp(1j* k * distance * (1 - (FX * wavelength) ** 2 - (FY * wavelength) ** 2) ** 0.5).to(device)
    return H

point_wise(target, wavelength, distance, dx, device, lens_size=401)

Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

Parameters:

  • target
               Input target to be converted into a hologram (target values should be between zero and one).
    
  • wavelength
               Wavelength of the electric field.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • device
               Device type (e.g., cuda or cpu).
    
  • lens_size
               Size of lens for masking sub holograms (in pixels).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

Source code in odak/learn/wave/classical.py
def point_wise(target, wavelength, distance, dx, device, lens_size=401):
    """
    Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

    Parameters
    ----------
    target           : torch.float
                       Input target to be converted into a hologram (target values should be between zero and one).
    wavelength       : float
                       Wavelength of the electric field.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    device           : torch.device
                       Device type (e.g., cuda or cpu).
    lens_size        : int
                       Size of lens for masking sub holograms (in pixels).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    """
    target = zero_pad(target)
    nx, ny = target.shape
    k = wavenumber(wavelength)
    ones = torch.ones(target.shape, requires_grad=False).to(device)
    x = torch.linspace(-nx/2, nx/2, nx).to(device)
    y = torch.linspace(-ny/2, ny/2, ny).to(device)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = (X**2+Y**2)**0.5
    mask = (torch.abs(Z) <= lens_size)
    mask[mask > 1] = 1
    fz = quadratic_phase_function(nx, ny, k, focal=-distance, dx=dx).to(device)
    A = torch.nan_to_num(target**0.5, nan=0.0)
    fz = mask*fz
    FA = torch.fft.fft2(torch.fft.fftshift(A))
    FFZ = torch.fft.fft2(torch.fft.fftshift(fz))
    H = torch.mul(FA, FFZ)
    hologram = torch.fft.ifftshift(torch.fft.ifft2(H))
    hologram = crop_center(hologram)
    return hologram
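
Example: a minimal sketch of the point-wise method; the target values and optical parameters below are arbitrary placeholders.

import torch
import odak

device = torch.device('cpu')
target = torch.rand(512, 512, device = device)          # target intensity in [0, 1]
hologram = odak.learn.wave.point_wise(
                                      target,
                                      wavelength = 515e-9,
                                      distance = 0.15,
                                      dx = 8e-6,
                                      device = device,
                                      lens_size = 401
                                     )
phase_only = torch.angle(hologram)                      # phase pattern for a phase-only SLM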

propagate_beam(field, k, distance, dx, wavelength, propagation_type='Bandlimited Angular Spectrum', kernel=None, zero_padding=[True, False, True], aperture=1.0)

Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • propagation_type (str, default: 'Bandlimited Angular Spectrum' ) –
               Type of the propagation.
               The options are Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer, and custom (which requires a kernel).
    
  • kernel
               Custom complex kernel.
    
  • zero_padding
               Zero pads the input field if the first item in the list is set to True.
               Zero pads in the Fourier domain if the second item in the list is set to True.
               Crops the result to half resolution if the third item in the list is set to True.
               Note that in Fraunhofer propagation, the second item has no effect.
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def propagate_beam(
                   field, 
                   k, 
                   distance, 
                   dx, 
                   wavelength, 
                   propagation_type='Bandlimited Angular Spectrum', 
                   kernel = None, 
                   zero_padding = [True, False, True],
                   aperture = 1.
                  ):
    """
    Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    propagation_type : str
                       Type of the propagation.
                       The options are Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer, and custom (which requires a kernel).
    kernel           : torch.complex
                       Custom complex kernel.
    zero_padding     : list
                       Zero pads the input field if the first item in the list is set to True.
                       Zero pads in the Fourier domain if the second item in the list is set to True.
                       Crops the result to half resolution if the third item in the list is set to True.
                       Note that in Fraunhofer propagation, the second item has no effect.

    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    if zero_padding[0]:
        field = zero_pad(field)
    if propagation_type == 'Angular Spectrum':
        result = angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
    elif propagation_type == 'Bandlimited Angular Spectrum':
        result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
    elif propagation_type == 'Transfer Function Fresnel':
        result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
    elif propagation_type == 'custom':
        result = custom(field, kernel, zero_padding[1], aperture = aperture)
    elif propagation_type == 'Fraunhofer':
        result = fraunhofer(field, k, distance, dx, wavelength)
    else:
        logging.warning('Propagation type not recognized')
        assert True == False
    if zero_padding[2]:
        result = crop_center(result)
    return result
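
Example: a minimal sketch of the generic propagate_beam interface, zero padding the input before propagation and cropping the result back to the original resolution; all values are arbitrary placeholders.

import torch
import numpy as np
import odak

wavelength = 515e-9
dx = 8e-6
distance = 1e-2
k = 2 * np.pi / wavelength
field = torch.zeros(512, 512, dtype = torch.complex64)
field[206:306, 206:306] = 1.                            # a simple square aperture as the input field
result = odak.learn.wave.propagate_beam(
                                         field,
                                         k,
                                         distance,
                                         dx,
                                         wavelength,
                                         propagation_type = 'Bandlimited Angular Spectrum',
                                         zero_padding = [True, False, True]
                                        )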

shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None)

Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. Coded following the implementation at https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207 and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

Parameters:

  • phase
               Phase value of a phase-only hologram.
    
  • depth_shift
               Distance in meters.
    
  • pixel_pitch
               Pixel pitch size in meters.
    
  • wavelength
               Wavelength of light.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Beam propagation type. For more see odak.learn.wave.propagate_beam().
    
  • kernel_length
               Kernel length for the Gaussian blur kernel.
    
  • sigma
               Standard deviation for the Gaussian blur kernel.
    
  • amplitude
               Amplitude value of a complex hologram.
    
Source code in odak/learn/wave/classical.py
def shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None):
    """
    Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. Coded following the implementation [here](https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207) and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

    Parameters
    ----------
    phase            : torch.tensor
                       Phase value of a phase-only hologram.
    depth_shift      : float
                       Distance in meters.
    pixel_pitch      : float
                       Pixel pitch size in meters.
    wavelength       : float
                       Wavelength of light.
    propagation_type : str
                       Beam propagation type. For more see odak.learn.wave.propagate_beam().
    kernel_length    : int
                       Kernel length for the Gaussian blur kernel.
    sigma            : float
                       Standard deviation for the Gaussian blur kernel.
    amplitude        : torch.tensor
                       Amplitude value of a complex hologram.
    """
    if type(amplitude) == type(None):
        amplitude = torch.ones_like(phase)
    hologram = generate_complex_field(amplitude, phase)
    k = wavenumber(wavelength)
    hologram_padded = zero_pad(hologram)
    shifted_field_padded = propagate_beam(
                                          hologram_padded,
                                          k,
                                          depth_shift,
                                          pixel_pitch,
                                          wavelength,
                                          propagation_type
                                         )
    shifted_field = crop_center(shifted_field_padded)
    phase_shift = torch.exp(torch.tensor([-2 * np.pi * depth_shift / wavelength]).to(phase.device))
    shift = torch.cos(phase_shift) + 1j * torch.sin(phase_shift)
    shifted_complex_hologram = shifted_field * shift

    if kernel_length > 0 and sigma >0:
        blur_kernel = generate_2d_gaussian(
                                           [kernel_length, kernel_length],
                                           [sigma, sigma]
                                          ).to(phase.device)
        blur_kernel = blur_kernel.unsqueeze(0)
        blur_kernel = blur_kernel.unsqueeze(0)
        field_imag = torch.imag(shifted_complex_hologram)
        field_real = torch.real(shifted_complex_hologram)
        field_imag = field_imag.unsqueeze(0)
        field_imag = field_imag.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_imag = torch.nn.functional.conv2d(field_imag, blur_kernel, padding='same')
        field_real = torch.nn.functional.conv2d(field_real, blur_kernel, padding='same')
        shifted_complex_hologram = torch.complex(field_real, field_imag)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)

    shifted_amplitude = calculate_amplitude(shifted_complex_hologram)
    shifted_amplitude = shifted_amplitude / torch.amax(shifted_amplitude, [0,1])

    shifted_phase = calculate_phase(shifted_complex_hologram)
    phase_zero_mean = shifted_phase - torch.mean(shifted_phase)

    phase_offset = torch.arccos(shifted_amplitude)
    phase_low = phase_zero_mean - phase_offset
    phase_high = phase_zero_mean + phase_offset

    phase_only = torch.zeros_like(phase)
    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
    return phase_only
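
Example: a minimal sketch for shifting a phase-only hologram to a new depth with shift_w_double_phase; the input phase and the shift below are arbitrary placeholders.

import torch
import numpy as np
import odak

phase = (torch.rand(512, 512) * 2. - 1.) * np.pi        # an arbitrary phase-only hologram
shifted_phase = odak.learn.wave.shift_w_double_phase(
                                                      phase,
                                                      depth_shift = 1e-3,
                                                      pixel_pitch = 8e-6,
                                                      wavelength = 515e-9
                                                     )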

stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type='Bandlimited Angular Spectrum', n_iteration=100, loss_function=None, learning_rate=0.1)

Definition to generate phase and reconstruction from target image via stochastic gradient descent.

Parameters:

  • target
                        Target field amplitude [m x n].
                        Keep the target values between zero and one.
    
  • wavelength
                         Wavelength of the electric field.
    
  • distance
                         Hologram plane distance with respect to the SLM plane.
    
  • pixel_pitch
                        SLM pixel pitch in meters.
    
  • propagation_type
                        Type of the propagation (see odak.learn.wave.propagate_beam()).
    
  • n_iteration
                         Number of iterations.
    
  • loss_function
                         If None, it is set to the L2 (MSE) loss.
    
  • learning_rate
                        Learning rate.
    

Returns:

  • hologram ( Tensor ) –

    Phase-only hologram as a torch array.

  • reconstruction ( Tensor ) –

    Reconstruction as a torch array (the complex field obtained with the optimized hologram).

Source code in odak/learn/wave/classical.py
def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type = 'Bandlimited Angular Spectrum', n_iteration = 100, loss_function = None, learning_rate = 0.1):
    """
    Definition to generate phase and reconstruction from target image via stochastic gradient descent.

    Parameters
    ----------
    target                    : torch.Tensor
                                Target field amplitude [m x n].
                                Keep the target values between zero and one.
    wavelength                : double
                                Wavelength of the electric field.
    distance                  : double
                                Hologram plane distance with respect to the SLM plane.
    pixel_pitch               : float
                                SLM pixel pitch in meters.
    propagation_type          : str
                                Type of the propagation (see odak.learn.wave.propagate_beam()).
    n_iteration               : int
                                Number of iterations.
    loss_function             : function
                                If None, it is set to the L2 (MSE) loss.
    learning_rate             : float
                                Learning rate.

    Returns
    -------
    hologram                  : torch.Tensor
                                Phase only hologram as torch array

    reconstruction_intensity  : torch.Tensor
                                Reconstruction as torch array

    """
    phase = torch.randn_like(target, requires_grad = True)
    k = wavenumber(wavelength)
    optimizer = torch.optim.Adam([phase], lr = learning_rate)
    if type(loss_function) == type(None):
        loss_function = torch.nn.MSELoss()
    t = tqdm(range(n_iteration), leave = False, dynamic_ncols = True)
    for i in t:
        optimizer.zero_grad()
        hologram = generate_complex_field(1., phase)
        reconstruction = propagate_beam(
                                        hologram, 
                                        k, 
                                        distance, 
                                        pixel_pitch, 
                                        wavelength, 
                                        propagation_type, 
                                        zero_padding = [True, False, True]
                                       )
        reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
        loss = loss_function(reconstruction_intensity, target)
        description = "Loss:{:.4f}".format(loss.item())
        loss.backward(retain_graph = True)
        optimizer.step()
        t.set_description(description)
    print(description)
    torch.no_grad()
    hologram = generate_complex_field(1., phase)
    reconstruction = propagate_beam(
                                    hologram, 
                                    k, 
                                    distance, 
                                    pixel_pitch, 
                                    wavelength, 
                                    propagation_type, 
                                    zero_padding = [True, False, True]
                                   )
    return hologram, reconstruction
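
Example: a minimal sketch that optimizes a phase-only hologram for a target image with stochastic_gradient_descent; the random target and the settings below are arbitrary placeholders.

import torch
import odak

target = torch.rand(512, 512)                           # target intensity in [0, 1]
hologram, reconstruction = odak.learn.wave.stochastic_gradient_descent(
                                                                        target,
                                                                        wavelength = 515e-9,
                                                                        distance = 0.15,
                                                                        pixel_pitch = 8e-6,
                                                                        n_iteration = 50,
                                                                        learning_rate = 0.1
                                                                       )
phase_only = torch.angle(hologram)                      # optimized phase pattern
reconstruction_intensity = reconstruction.abs() ** 2    # second return value is the complex reconstruction field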

transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution-based Fresnel (transfer function) approximation for beam propagation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
    """
    A definition to calculate convolution-based Fresnel (transfer function) approximation for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_transfer_function_fresnel_kernel(
                                             field.shape[-2], 
                                             field.shape[-1], 
                                             dx = dx, 
                                             wavelength = wavelength, 
                                             distance = distance, 
                                             device = field.device
                                            )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result

holographic_display

A class for simulating a holographic display.

Source code in odak/learn/wave/hardware.py
class holographic_display():
    """
    A class for simulating a holographic display.
    """
    def __init__(self, 
                 wavelengths, 
                 pixel_pitch = 3.74e-6,
                 resolution = [1920, 1080], 
                 volume_depth = 0.01,
                 number_of_depth_layers = 10,
                 image_location_offset = 0.005,
                 pinhole_size = 1500,
                 pad = [True, True],
                 illumination = None,
                 propagation_type = 'Bandlimited Angular Spectrum',
                 device = None
                ):
        """
        Parameters
        ----------
        wavelengths            : list
                                 List of wavelengths in meters (e.g., 531e-9).
        pixel_pitch            : float
                                 Pixel pitch in meters (e.g., 8e-6).
        resolution             : list
                                 Resolution (e.g., 1920 x 1080).
        volume_depth           : float
                                 Volume depth in meters.
        number_of_depth_layers : int
                                 Number of depth layers.
        image_location_offset  : float
                                 Image location offset in depth.
        pinhole_size           : int
                                 Size of the pinhole aperture in pixels in a 4f imaging system.
        pad                    : list
                                 Set it to a list of True booleans to zero pad and crop each time the field is propagated (avoids aliasing).
        illumination           : torch.tensor
                                 Provide the amplitude profile of the illumination source.
        device                 : torch.device
                                 Device to be used (e.g., cuda, cpu).
        """
        self.device = device
        if isinstance(self.device, type(None)):
            self.device = torch.device("cpu")
        self.pad = pad
        self.wavelengths = wavelengths
        self.resolution = resolution
        self.pixel_pitch = pixel_pitch
        self.volume_depth = volume_depth
        self.image_location_offset = torch.tensor(image_location_offset, device = device)
        self.number_of_depth_layers = number_of_depth_layers
        self.number_of_wavelengths = len(self.wavelengths)
        self.propagation_type = propagation_type
        self.pinhole_size = pinhole_size
        self.init_distances()
        self.init_amplitude(illumination)
        self.init_aperture()
        self.generate_kernels()


    def init_aperture(self):
        """
        Internal function to initialize aperture.
        """
        self.aperture = circular_binary_mask(
                                             self.resolution[0] * 2,
                                             self.resolution[1] * 2,
                                             self.pinhole_size,
                                            ).to(self.device) * 1.


    def init_amplitude(self, illumination):
        """
        Internal function to set the amplitude of the illumination source.
        """
        self.amplitude = torch.ones(
                                    self.resolution[0],
                                    self.resolution[1],
                                    requires_grad = False,
                                    device = self.device
                                   )
        if not isinstance(illumination, type(None)):
            self.amplitude = illumination


    def init_distances(self):
        """
        Internal function to set the image plane distances.
        """
        if self.number_of_depth_layers > 1:
            self.distances = torch.linspace(
                                            -self.volume_depth / 2., 
                                            self.volume_depth / 2., 
                                            self.number_of_depth_layers,
                                            device = self.device
                                           ) + self.image_location_offset
        else:
            self.distances = torch.tensor([self.image_location_offset], device = self.device)


    def forward(self, input_field, wavelength_id, depth_id):
        """

        Function that represents the forward model in hologram optimization.

        Parameters
        ----------
        input_field         : torch.tensor
                              Complex input field.
        wavelength_id       : int
                              Identifying the color primary to be used.
        depth_id            : int
                              Identifying the depth layer to be used.

        Returns
        -------
        output_field        : torch.tensor
                              Propagated output complex field.
        """
        if self.pad[0]:
            input_field_padded = zero_pad(input_field)
        else:
            input_field_padded = input_field
        H = self.kernels[depth_id, wavelength_id].detach().clone()
        U_I = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(input_field_padded)))
        U_O = (U_I * self.aperture) * H
        output_field_padded = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(U_O)))
        if self.pad[1]:
            output_field = crop_center(output_field_padded)
        else:
            output_field = output_field_padded
        return output_field


    def generate_kernels(self):
        """
        Internal function to generate light transport kernels.
        """
        if self.pad[0]:
            multiplier = 2
        else:
            multiplier = 1         
        self.kernels = torch.zeros(
                                   self.number_of_depth_layers,
                                   self.number_of_wavelengths,
                                   self.resolution[0] * multiplier,
                                   self.resolution[1] * multiplier,
                                   device = self.device,
                                   dtype = torch.complex64
                                  )
        for distance_id, distance in enumerate(self.distances):
            for wavelength_id, wavelength in enumerate(self.wavelengths):
                 self.kernels[distance_id, wavelength_id] = get_propagation_kernel(
                                                                                   nu = self.kernels.shape[-2],
                                                                                   nv = self.kernels.shape[-1],
                                                                                   dx = self.pixel_pitch, 
                                                                                   wavelength = wavelength, 
                                                                                   distance = distance,
                                                                                   device = self.device,
                                                                                   propagation_type = self.propagation_type
                                                                                  )


    def reconstruct(self, hologram_phases, laser_powers):
        """
        Internal function to reconstruct a given hologram.


        Parameters
        ----------
        hologram_phases            : torch.tensor
                                     A monochrome hologram phase [m x n].
        laser_powers               : torch.tensor
                                     Laser powers for each hologram phase.
                                     Values must be between zero and one.

        Returns
        -------
        reconstruction_intensities : torch.tensor
                                     Reconstructed frames [w x k x l x m x n].
                                     First dimension represents the number of frames.
                                     Second dimension represents the depth layers.
                                     Third dimension is for the color primaries (each wavelength provided).
        """
        self.number_of_frames = hologram_phases.shape[0]
        reconstruction_intensities = torch.zeros(
                                                 self.number_of_frames,
                                                 self.number_of_depth_layers,
                                                 self.number_of_wavelengths,
                                                 self.resolution[0],
                                                 self.resolution[1],
                                                 device = self.device
                                                )
        for frame_id in range(self.number_of_frames): 
            for depth_id in range(self.number_of_depth_layers): 
                for wavelength_id in range(self.number_of_wavelengths):
                    laser_power = laser_powers[frame_id][wavelength_id]
                    hologram = generate_complex_field(
                                                      laser_power * self.amplitude, 
                                                      hologram_phases[frame_id]
                                                     )
                    reconstruction_field = self.forward(hologram, wavelength_id, depth_id)
                    reconstruction_intensities[frame_id, depth_id, wavelength_id] = calculate_amplitude(reconstruction_field) ** 2
        return reconstruction_intensities
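
Example: a minimal sketch that simulates reconstructions of multi-frame phase-only holograms with the holographic_display class; the small resolution, random phases and laser powers are arbitrary placeholders chosen to keep the sketch light.

import torch
import numpy as np
import odak

device = torch.device('cpu')
display = odak.learn.wave.holographic_display(
                                              wavelengths = [639e-9, 515e-9, 473e-9],
                                              pixel_pitch = 3.74e-6,
                                              resolution = [256, 256],
                                              volume_depth = 5e-3,
                                              number_of_depth_layers = 2,
                                              image_location_offset = 5e-3,
                                              pinhole_size = 200,
                                              device = device
                                             )
hologram_phases = torch.rand(3, 256, 256, device = device) * 2. * np.pi  # three frames of phase patterns
laser_powers = torch.ones(3, 3, device = device)                         # per-frame, per-wavelength powers in [0, 1]
reconstructions = display.reconstruct(hologram_phases, laser_powers)     # [frames x depths x wavelengths x m x n]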

__init__(wavelengths, pixel_pitch=3.74e-06, resolution=[1920, 1080], volume_depth=0.01, number_of_depth_layers=10, image_location_offset=0.005, pinhole_size=1500, pad=[True, True], illumination=None, propagation_type='Bandlimited Angular Spectrum', device=None)

Parameters:

  • wavelengths
                     List of wavelengths in meters (e.g., 531e-9).
    
  • pixel_pitch
                     Pixel pitch in meters (e.g., 8e-6).
    
  • resolution
                     Resolution (e.g., 1920 x 1080).
    
  • volume_depth
                     Volume depth in meters.
    
  • number_of_depth_layers (int, default: 10 ) –
                     Number of depth layers.
    
  • image_location_offset
                     Image location offset in depth.
    
  • pinhole_size
                      Size of the pinhole aperture in pixels in a 4f imaging system.
    
  • pad
                      Set it to a list of True booleans to zero pad and crop each time the field is propagated (avoids aliasing).
    
  • illumination
                     Provide the amplitude profile of the illumination source.
    
  • device
                     Device to be used (e.g., cuda, cpu).
    
Source code in odak/learn/wave/hardware.py
def __init__(self, 
             wavelengths, 
             pixel_pitch = 3.74e-6,
             resolution = [1920, 1080], 
             volume_depth = 0.01,
             number_of_depth_layers = 10,
             image_location_offset = 0.005,
             pinhole_size = 1500,
             pad = [True, True],
             illumination = None,
             propagation_type = 'Bandlimited Angular Spectrum',
             device = None
            ):
    """
    Parameters
    ----------
    wavelengths            : list
                             List of wavelengths in meters (e.g., 531e-9).
    pixel_pitch            : float
                             Pixel pitch in meters (e.g., 8e-6).
    resolution             : list
                             Resolution (e.g., 1920 x 1080).
    volume_depth           : float
                             Volume depth in meters.
    number_of_depth_layers : int
                             Number of depth layers.
    image_location_offset  : float
                             Image location offset in depth.
    pinhole_size           : int
                             Size of the pinhole aperture in pixels in a 4f imaging system.
    pad                    : list
                             Set it to a list of True booleans to zero pad and crop each time the field is propagated (avoids aliasing).
    illumination           : torch.tensor
                             Provide the amplitude profile of the illumination source.
    device                 : torch.device
                             Device to be used (e.g., cuda, cpu).
    """
    self.device = device
    if isinstance(self.device, type(None)):
        self.device = torch.device("cpu")
    self.pad = pad
    self.wavelengths = wavelengths
    self.resolution = resolution
    self.pixel_pitch = pixel_pitch
    self.volume_depth = volume_depth
    self.image_location_offset = torch.tensor(image_location_offset, device = device)
    self.number_of_depth_layers = number_of_depth_layers
    self.number_of_wavelengths = len(self.wavelengths)
    self.propagation_type = propagation_type
    self.pinhole_size = pinhole_size
    self.init_distances()
    self.init_amplitude(illumination)
    self.init_aperture()
    self.generate_kernels()

forward(input_field, wavelength_id, depth_id)

Function that represents the forward model in hologram optimization.

Parameters:

  • input_field
                  Complex input field.
    
  • wavelength_id
                  Identifying the color primary to be used.
    
  • depth_id
                  Identifying the depth layer to be used.
    

Returns:

  • output_field ( tensor ) –

    Propagated output complex field.

Source code in odak/learn/wave/hardware.py
def forward(self, input_field, wavelength_id, depth_id):
    """

    Function that represents the forward model in hologram optimization.

    Parameters
    ----------
    input_field         : torch.tensor
                          Complex input field.
    wavelength_id       : int
                          Identifying the color primary to be used.
    depth_id            : int
                          Identifying the depth layer to be used.

    Returns
    -------
    output_field        : torch.tensor
                          Propagated output complex field.
    """
    if self.pad[0]:
        input_field_padded = zero_pad(input_field)
    else:
        input_field_padded = input_field
    H = self.kernels[depth_id, wavelength_id].detach().clone()
    U_I = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(input_field_padded)))
    U_O = (U_I * self.aperture) * H
    output_field_padded = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(U_O)))
    if self.pad[1]:
        output_field = crop_center(output_field_padded)
    else:
        output_field = output_field_padded
    return output_field

generate_kernels()

Internal function to generate light transport kernels.

Source code in odak/learn/wave/hardware.py
def generate_kernels(self):
    """
    Internal function to generate light transport kernels.
    """
    if self.pad[0]:
        multiplier = 2
    else:
        multiplier = 1         
    self.kernels = torch.zeros(
                               self.number_of_depth_layers,
                               self.number_of_wavelengths,
                               self.resolution[0] * multiplier,
                               self.resolution[1] * multiplier,
                               device = self.device,
                               dtype = torch.complex64
                              )
    for distance_id, distance in enumerate(self.distances):
        for wavelength_id, wavelength in enumerate(self.wavelengths):
             self.kernels[distance_id, wavelength_id] = get_propagation_kernel(
                                                                               nu = self.kernels.shape[-2],
                                                                               nv = self.kernels.shape[-1],
                                                                               dx = self.pixel_pitch, 
                                                                               wavelength = wavelength, 
                                                                               distance = distance,
                                                                               device = self.device,
                                                                               propagation_type = self.propagation_type
                                                                              )

init_amplitude(illumination)

Internal function to set the amplitude of the illumination source.

Source code in odak/learn/wave/hardware.py
def init_amplitude(self, illumination):
    """
    Internal function to set the amplitude of the illumination source.
    """
    self.amplitude = torch.ones(
                                self.resolution[0],
                                self.resolution[1],
                                requires_grad = False,
                                device = self.device
                               )
    if not isinstance(illumination, type(None)):
        self.amplitude = illumination

init_aperture()

Internal function to initialize aperture.

Source code in odak/learn/wave/hardware.py
def init_aperture(self):
    """
    Internal function to initialize aperture.
    """
    self.aperture = circular_binary_mask(
                                         self.resolution[0] * 2,
                                         self.resolution[1] * 2,
                                         self.pinhole_size,
                                        ).to(self.device) * 1.

init_distances()

Internal function to set the image plane distances.

Source code in odak/learn/wave/hardware.py
def init_distances(self):
    """
    Internal function to set the image plane distances.
    """
    if self.number_of_depth_layers > 1:
        self.distances = torch.linspace(
                                        -self.volume_depth / 2., 
                                        self.volume_depth / 2., 
                                        self.number_of_depth_layers,
                                        device = self.device
                                       ) + self.image_location_offset
    else:
        self.distances = torch.tensor([self.image_location_offset], device = self.device)

reconstruct(hologram_phases, laser_powers)

Internal function to reconstruct a given hologram.

Parameters:

  • hologram_phases
                         Hologram phases for each frame [w x m x n], where w is the number of frames.
    
  • laser_powers
                         Laser powers for each hologram phase.
                         Values must be between zero and one.
    

Returns:

  • reconstruction_intensities ( tensor ) –

    Reconstructed frames [w x k x l x m x n]. First dimension represents the number of frames. Second dimension represents the depth layers. Third dimension is for the color primaries (each wavelength provided).

Source code in odak/learn/wave/hardware.py
def reconstruct(self, hologram_phases, laser_powers):
    """
    Internal function to reconstruct a given hologram.


    Parameters
    ----------
    hologram_phases            : torch.tensor
                                 Hologram phases for each frame [w x m x n], where w is the number of frames.
    laser_powers               : torch.tensor
                                 Laser powers for each hologram phase.
                                 Values must be between zero and one.

    Returns
    -------
    reconstruction_intensities : torch.tensor
                                 Reconstructed frames [w x k x l x m x n].
                                 First dimension represents the number of frames.
                                 Second dimension represents the depth layers.
                                 Third dimension is for the color primaries (each wavelength provided).
    """
    self.number_of_frames = hologram_phases.shape[0]
    reconstruction_intensities = torch.zeros(
                                             self.number_of_frames,
                                             self.number_of_depth_layers,
                                             self.number_of_wavelengths,
                                             self.resolution[0],
                                             self.resolution[1],
                                             device = self.device
                                            )
    for frame_id in range(self.number_of_frames): 
        for depth_id in range(self.number_of_depth_layers): 
            for wavelength_id in range(self.number_of_wavelengths):
                laser_power = laser_powers[frame_id][wavelength_id]
                hologram = generate_complex_field(
                                                  laser_power * self.amplitude, 
                                                  hologram_phases[frame_id]
                                                 )
                reconstruction_field = self.forward(hologram, wavelength_id, depth_id)
                reconstruction_intensities[frame_id, depth_id, wavelength_id] = calculate_amplitude(reconstruction_field) ** 2
    return reconstruction_intensities
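
As a reference, the snippet below sketches a single step of the reconstruction loop above, assuming generate_complex_field and calculate_amplitude are imported from odak.learn.wave as in the source; the propagation itself is omitted.

import torch
from odak.learn.wave import generate_complex_field, calculate_amplitude

amplitude = torch.ones(256, 256)                        # illumination amplitude
phase = torch.rand(256, 256) * 2. * 3.141592653589793   # example hologram phase for one frame
laser_power = 0.5                                       # value between zero and one
hologram = generate_complex_field(laser_power * amplitude, phase)
intensity = calculate_amplitude(hologram) ** 2          # equals 0.25 everywhere before propagation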

blazed_grating(nx, ny, levels=2, axis='x')

A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • levels
           Number of grating levels, i.e., the grating period in pixels.
    
  • axis
           Axis of the blazed grating. It could be `x` or `y`.
    
Source code in odak/learn/wave/lens.py
def blazed_grating(nx, ny, levels = 2, axis = 'x'):
    """
    A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.


    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    levels       : int
                   Number of grating levels, i.e., the grating period in pixels.
    axis         : str
                   Axis of the blazed grating. It could be `x` or `y`.

    """
    if levels < 2:
        levels = 2
    x = (torch.abs(torch.arange(-nx, 0)) % levels) / levels * (2 * np.pi)
    y = (torch.abs(torch.arange(-ny, 0)) % levels) / levels * (2 * np.pi)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'x':
        blazed_grating = torch.exp(1j * X)
    elif axis == 'y':
        blazed_grating = torch.exp(1j * Y)
    return blazed_grating
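
A short usage sketch with illustrative values:

import torch
from odak.learn.wave import blazed_grating

grating = blazed_grating(512, 512, levels = 16, axis = 'x')  # sixteen-pixel period along x
phase = torch.angle(grating)                                 # sawtooth phase ramp in radians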

linear_grating(nx, ny, every=2, add=None, axis='x')

A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • every
          Adds the `add` value at every `every` pixels along the chosen axis.
    
  • add
         Angle to be added.
    
  • axis
          Axis, either `x`, `y`, or both (`xy`).
    

Returns:

  • field ( tensor ) –

    Linear grating term.

Source code in odak/learn/wave/lens.py
def linear_grating(nx, ny, every = 2, add = None, axis = 'x'):
    """
    A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    every      : int
                 Adds the `add` value at every `every` pixels along the chosen axis.
    add        : float
                 Angle to be added.
    axis       : string
                 Axis, either `x`, `y`, or both (`xy`).

    Returns
    ----------
    field      : torch.tensor
                 Linear grating term.
    """
    if isinstance(add, type(None)):
        add = np.pi
    grating = torch.zeros((nx, ny), dtype=torch.complex64)
    if axis == 'x':
        grating[::every, :] = torch.exp(torch.tensor(1j*add))
    if axis == 'y':
        grating[:, ::every] = torch.exp(torch.tensor(1j*add))
    if axis == 'xy':
        checker = np.indices((nx, ny)).sum(axis=0) % every
        checker = torch.from_numpy(checker)
        checker += 1
        checker = checker % 2
        grating = torch.exp(1j*checker*add)
    return grating
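
A short usage sketch with illustrative values:

from odak.learn.wave import linear_grating

grating = linear_grating(512, 512, every = 2, axis = 'x')  # pi phase step on every other row along x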

prism_grating(nx, ny, k, angle, dx=0.001, axis='x', phase_offset=0.0)

A definition to generate 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics express 16.22 (2008): 18275-18287. for more.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • k
           See odak.wave.wavenumber for more.
    
  • angle
           Tilt angle of the prism in degrees.
    
  • dx
           Pixel pitch.
    
  • axis
           Axis of the prism.
    
  • phase_offset (float, default: 0.0 ) –
           Phase offset in degrees. Default is zero.
    

Returns:

  • prism ( tensor ) –

    Generated phase function for a prism.

Source code in odak/learn/wave/lens.py
def prism_grating(nx, ny, k, angle, dx = 0.001, axis = 'x', phase_offset = 0.):
    """
    A definition to generate 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics express 16.22 (2008): 18275-18287. for more.

    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    k            : odak.wave.wavenumber
                   See odak.wave.wavenumber for more.
    angle        : float
                   Tilt angle of the prism in degrees.
    dx           : float
                   Pixel pitch.
    axis         : str
                   Axis of the prism.
    phase_offset : float
                   Phase offset in degrees. Default is zero.

    Returns
    ----------
    prism        : torch.tensor
                   Generated phase function for a prism.
    """
    angle = torch.deg2rad(torch.tensor([angle]))
    phase_offset = torch.deg2rad(torch.tensor([phase_offset]))
    x = torch.arange(0, nx) * dx
    y = torch.arange(0, ny) * dx
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'y':
        phase = k * torch.sin(angle) * Y + phase_offset
        prism = torch.exp(-1j * phase)
    elif axis == 'x':
        phase = k * torch.sin(angle) * X + phase_offset
        prism = torch.exp(-1j * phase)
    return prism
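
A short usage sketch; the wavenumber is computed directly as 2π/λ here instead of calling odak.wave.wavenumber:

import numpy as np
from odak.learn.wave import prism_grating

wavelength = 532e-9                                   # 532 nm illumination
k = 2 * np.pi / wavelength                            # wavenumber
prism = prism_grating(nx = 512, ny = 512, k = k, angle = 0.1, dx = 8e-6, axis = 'x')  # 0.1 degree tilt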

quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • k
         See odak.wave.wavenumber for more.
    
  • focal
         Focal length of the quadratic phase function.
    
  • dx
         Pixel pitch.
    
  • offset
         Deviation from the center along X and Y axes.
    

Returns:

  • function ( tensor ) –

    Generated quadratic phase function.

Source code in odak/learn/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):
    """ 
    A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    k          : odak.wave.wavenumber
                 See odak.wave.wavenumber for more.
    focal      : float
                 Focal length of the quadratic phase function.
    dx         : float
                 Pixel pitch.
    offset     : list
                 Deviation from the center along X and Y axes.

    Returns
    -------
    function   : torch.tensor
                 Generated quadratic phase function.
    """
    size = [nx, ny]
    x = torch.linspace(-size[0] * dx / 2, size[0] * dx / 2, size[0]) - offset[1] * dx
    y = torch.linspace(-size[1] * dx / 2, size[1] * dx / 2, size[1]) - offset[0] * dx
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = X**2 + Y**2
    qwf = torch.exp(-0.5j * k / focal * Z)
    return qwf
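
A short usage sketch; the wavenumber is again computed directly as 2π/λ:

import numpy as np
from odak.learn.wave import quadratic_phase_function

wavelength = 532e-9
k = 2 * np.pi / wavelength
lens_phase = quadratic_phase_function(nx = 512, ny = 512, k = k, focal = 0.4, dx = 8e-6)  # 40 cm focal length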

multiplane_loss

Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for the defocused parts of an image.

Source code in odak/learn/wave/loss.py
class multiplane_loss():
    """
    Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for the defocused parts of an image.
    """


    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
                 target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6, 0.], 
                 multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
        """
        Parameters
        ----------
        target_image      : torch.tensor
                            Color target image [3 x m x n].
        target_depth      : torch.tensor
                            Monochrome target depth, same resolution as target_image.
        target_blur_size  : int
                            Maximum target blur size.
        blur_ratio        : float
                            Blur ratio, a value between zero and one.
        number_of_planes  : int
                            Number of planes.
        weights           : list
                            Weights of the loss function.
        multiplier        : float
                            Multiplier to multiply the targets with.
        scheme            : str
                            The type of the loss, `naive` without defocus or `defocus` with defocus.
        reduction         : str
                            Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
        device            : torch.device
                            Device to be used (e.g., cuda, cpu, opencl).
        """
        self.device = device
        self.target_image     = target_image.float().to(self.device)
        self.target_depth     = target_depth.float().to(self.device)
        self.target_blur_size = target_blur_size
        if self.target_blur_size % 2 == 0:
            self.target_blur_size += 1
        self.number_of_planes = number_of_planes
        self.multiplier       = multiplier
        self.weights          = weights
        self.reduction        = reduction
        self.blur_ratio       = blur_ratio
        self.set_targets()
        if scheme == 'defocus':
            self.add_defocus_blur()
        self.loss_function = torch.nn.MSELoss(reduction = self.reduction)


    def get_targets(self):
        """
        Returns
        -------
        targets           : torch.tensor
                            Returns a copy of the targets.
        target_depth      : torch.tensor
                            Returns a copy of the normalized quantized depth map.

        """
        divider = self.number_of_planes - 1
        if divider == 0:
            divider = 1
        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider


    def set_targets(self):
        """
        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
        """
        self.target_depth = self.target_depth * (self.number_of_planes - 1)
        self.target_depth = torch.round(self.target_depth, decimals = 0)
        self.targets      = torch.zeros(
                                        self.number_of_planes,
                                        self.target_image.shape[0],
                                        self.target_image.shape[1],
                                        self.target_image.shape[2],
                                        requires_grad = False,
                                        device = self.device
                                       )
        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
        self.masks        = torch.zeros_like(self.targets)
        for i in range(self.number_of_planes):
            for ch in range(self.target_image.shape[0]):
                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
                new_target = self.target_image[ch] * mask
                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
                self.masks[i, ch] = mask.detach().clone() 


    def add_defocus_blur(self):
        """
        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
        """
        kernel_length = [self.target_blur_size, self.target_blur_size ]
        for ch in range(self.target_image.shape[0]):
            targets_cache = self.targets[:, ch].detach().clone()
            target = torch.sum(targets_cache, axis = 0)
            for i in range(self.number_of_planes):
                sigmas = torch.linspace(start = 0, end = self.target_blur_size, steps = self.number_of_planes)
                sigmas = sigmas - i * self.target_blur_size / (self.number_of_planes - 1 + 1e-10)
                defocus = torch.zeros_like(targets_cache[i])
                for j in range(self.number_of_planes):
                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                    if torch.sum(targets_cache[j]) > 0:
                        if i == j:
                            nsigma = [0., 0.]
                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                        kernel = kernel / torch.sum(kernel)
                        kernel = kernel.unsqueeze(0).unsqueeze(0)
                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
                self.targets[i, ch] = defocus
        self.targets = self.targets.detach().clone() * self.multiplier


    def __call__(self, image, target, plane_id = None):
        """
        Calculates the multiplane loss against a given target.

        Parameters
        ----------
        image         : torch.tensor
                        Image to compare with a target [3 x m x n].
        target        : torch.tensor
                        Target image for comparison [3 x m x n].
        plane_id      : int
                        Number of the plane under test.

        Returns
        -------
        loss          : torch.tensor
                        Computed loss.
        """
        l2 = self.weights[0] * self.loss_function(image, target)
        if isinstance(plane_id, type(None)):
            mask = self.masks
        else:
            mask= self.masks[plane_id, :]
        l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
        l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
        loss = l2 + l2_mask + l2_cor
        return loss
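
A short usage sketch with random placeholder data, assuming multiplane_loss is imported from odak.learn.wave as documented here:

import torch
from odak.learn.wave import multiplane_loss

target_image = torch.rand(3, 256, 256)                 # placeholder color target
target_depth = torch.rand(256, 256)                    # placeholder depth map in [0, 1]
loss_function = multiplane_loss(target_image, target_depth, number_of_planes = 4)
targets, focus_target, depth = loss_function.get_targets()
reconstruction = torch.rand(3, 256, 256)               # placeholder reconstruction of plane zero
loss = loss_function(reconstruction, targets[0], plane_id = 0)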

__call__(image, target, plane_id=None)

Calculates the multiplane loss against a given target.

Parameters:

  • image
            Image to compare with a target [3 x m x n].
    
  • target
            Target image for comparison [3 x m x n].
    
  • plane_id
            Number of the plane under test.
    

Returns:

  • loss ( tensor ) –

    Computed loss.

Source code in odak/learn/wave/loss.py
def __call__(self, image, target, plane_id = None):
    """
    Calculates the multiplane loss against a given target.

    Parameters
    ----------
    image         : torch.tensor
                    Image to compare with a target [3 x m x n].
    target        : torch.tensor
                    Target image for comparison [3 x m x n].
    plane_id      : int
                    Number of the plane under test.

    Returns
    -------
    loss          : torch.tensor
                    Computed loss.
    """
    l2 = self.weights[0] * self.loss_function(image, target)
    if isinstance(plane_id, type(None)):
        mask = self.masks
    else:
        mask= self.masks[plane_id, :]
    l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
    l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
    loss = l2 + l2_mask + l2_cor
    return loss

__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, weights=[1.0, 2.1, 0.6, 0.0], multiplier=1.0, scheme='defocus', reduction='mean', device=torch.device('cpu'))

Parameters:

  • target_image
                Color target image [3 x m x n].
    
  • target_depth
                Monochrome target depth, same resolution as target_image.
    
  • target_blur_size
                Maximum target blur size.
    
  • blur_ratio
                Blur ratio, a value between zero and one.
    
  • number_of_planes
                Number of planes.
    
  • weights
                Weights of the loss function.
    
  • multiplier
                Multiplier to multiply the targets with.
    
  • scheme
                The type of the loss, `naive` without defocus or `defocus` with defocus.
    
  • reduction
                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    
  • device
                Device to be used (e.g., cuda, cpu, opencl).
    
Source code in odak/learn/wave/loss.py
def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
             target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6, 0.], 
             multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
    """
    Parameters
    ----------
    target_image      : torch.tensor
                        Color target image [3 x m x n].
    target_depth      : torch.tensor
                        Monochrome target depth, same resolution as target_image.
    target_blur_size  : int
                        Maximum target blur size.
    blur_ratio        : float
                        Blur ratio, a value between zero and one.
    number_of_planes  : int
                        Number of planes.
    weights           : list
                        Weights of the loss function.
    multiplier        : float
                        Multiplier to multiply the targets with.
    scheme            : str
                        The type of the loss, `naive` without defocus or `defocus` with defocus.
    reduction         : str
                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    device            : torch.device
                        Device to be used (e.g., cuda, cpu, opencl).
    """
    self.device = device
    self.target_image     = target_image.float().to(self.device)
    self.target_depth     = target_depth.float().to(self.device)
    self.target_blur_size = target_blur_size
    if self.target_blur_size % 2 == 0:
        self.target_blur_size += 1
    self.number_of_planes = number_of_planes
    self.multiplier       = multiplier
    self.weights          = weights
    self.reduction        = reduction
    self.blur_ratio       = blur_ratio
    self.set_targets()
    if scheme == 'defocus':
        self.add_defocus_blur()
    self.loss_function = torch.nn.MSELoss(reduction = self.reduction)

add_defocus_blur()

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def add_defocus_blur(self):
    """
    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
    """
    kernel_length = [self.target_blur_size, self.target_blur_size ]
    for ch in range(self.target_image.shape[0]):
        targets_cache = self.targets[:, ch].detach().clone()
        target = torch.sum(targets_cache, axis = 0)
        for i in range(self.number_of_planes):
            sigmas = torch.linspace(start = 0, end = self.target_blur_size, steps = self.number_of_planes)
            sigmas = sigmas - i * self.target_blur_size / (self.number_of_planes - 1 + 1e-10)
            defocus = torch.zeros_like(targets_cache[i])
            for j in range(self.number_of_planes):
                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                if torch.sum(targets_cache[j]) > 0:
                    if i == j:
                        nsigma = [0., 0.]
                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                    kernel = kernel / torch.sum(kernel)
                    kernel = kernel.unsqueeze(0).unsqueeze(0)
                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
            self.targets[i, ch] = defocus
    self.targets = self.targets.detach().clone() * self.multiplier

get_targets()

Returns:

  • targets ( tensor ) –

    Returns a copy of the targets.

  • focus_target ( tensor ) –

    Returns a copy of the all-in-focus target image.

  • target_depth ( tensor ) –

    Returns a copy of the normalized quantized depth map.

Source code in odak/learn/wave/loss.py
def get_targets(self):
    """
    Returns
    -------
    targets           : torch.tensor
                        Returns a copy of the targets.
    target_depth      : torch.tensor
                        Returns a copy of the normalized quantized depth map.

    """
    divider = self.number_of_planes - 1
    if divider == 0:
        divider = 1
    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider

set_targets()

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):
    """
    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
    """
    self.target_depth = self.target_depth * (self.number_of_planes - 1)
    self.target_depth = torch.round(self.target_depth, decimals = 0)
    self.targets      = torch.zeros(
                                    self.number_of_planes,
                                    self.target_image.shape[0],
                                    self.target_image.shape[1],
                                    self.target_image.shape[2],
                                    requires_grad = False,
                                    device = self.device
                                   )
    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
    self.masks        = torch.zeros_like(self.targets)
    for i in range(self.number_of_planes):
        for ch in range(self.target_image.shape[0]):
            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
            new_target = self.target_image[ch] * mask
            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
            self.masks[i, ch] = mask.detach().clone() 

phase_gradient

Bases: Module

The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

This implements a convolution of the phase with a kernel.

The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.

Source code in odak/learn/wave/loss.py
class phase_gradient(nn.Module):

    """
    The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

    This implements a convolution of the phase with a kernel.

    The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.
    """


    def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
        """
        Parameters
        ----------
        kernel                  : torch.tensor
                                    Convolution filter kernel, 3 by 3 Laplacian kernel by default.
        loss                    : torch.nn.Module
                                    loss function, L2 Loss by default.
        """
        super(phase_gradient, self).__init__()
        self.device = device
        self.loss = loss
        if kernel == None:
            self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
        else:
            if len(kernel.shape) == 4:
                self.kernel = kernel
            else:
                self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
        self.kernel = Variable(self.kernel.to(self.device))


    def forward(self, phase):
        """
        Calculates the phase gradient Loss.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(phase.shape) == 2:
            phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
        edge_detect = self.functional_conv2d(phase)
        loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
        return loss_value


    def functional_conv2d(self, phase):
        """
        Calculates the gradient of the phase.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        edge_detect              : torch.tensor
                                    The computed phase gradient.
        """
        edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
        return edge_detect
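
A short usage sketch with a random phase map and the default Laplacian kernel:

import numpy as np
import torch
from odak.learn.wave import phase_gradient

regularizer = phase_gradient()
phase = torch.rand(256, 256) * 2. * np.pi
loss_value = regularizer(phase)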

__init__(kernel=None, loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel
                        Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    
  • loss
                        loss function, L2 Loss by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
    """
    Parameters
    ----------
    kernel                  : torch.tensor
                                Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    loss                    : torch.nn.Module
                                loss function, L2 Loss by default.
    """
    super(phase_gradient, self).__init__()
    self.device = device
    self.loss = loss
    if kernel == None:
        self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
    else:
        if len(kernel.shape) == 4:
            self.kernel = kernel
        else:
            self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
    self.kernel = Variable(self.kernel.to(self.device))

forward(phase)

Calculates the phase gradient Loss.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, phase):
    """
    Calculates the phase gradient Loss.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(phase.shape) == 2:
        phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
    edge_detect = self.functional_conv2d(phase)
    loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
    return loss_value

functional_conv2d(phase)

Calculates the gradient of the phase.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • edge_detect ( tensor ) –

    The computed phase gradient.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, phase):
    """
    Calculates the gradient of the phase.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    edge_detect              : torch.tensor
                                The computed phase gradient.
    """
    edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
    return edge_detect

speckle_contrast

Bases: Module

The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast, and mean and sigma are the mean and standard deviation of the intensity.

We refer to the following paper:

Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. doi:10.1038/s41598-020-75947-0.

Source code in odak/learn/wave/loss.py
class speckle_contrast(nn.Module):

    """
    The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast, and mean and sigma are the mean and standard deviation of the intensity.

    We refer to the following paper:

    Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. doi:10.1038/s41598-020-75947-0.
    """


    def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
        """
        Parameters
        ----------
        kernel_size             : torch.tensor
                                    Convolution filter kernel size, 11 by 11 average kernel by default.
        step_size               : tuple
                                    Convolution stride in height and width direction.
        loss                    : torch.nn.Module
                                    loss function, L2 Loss by default.
        """
        super(speckle_contrast, self).__init__()
        self.device = device
        self.loss = loss
        self.step_size = step_size
        self.kernel_size = kernel_size
        self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
        self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))


    def forward(self, intensity):
        """
        Calculates the speckle contrast Loss.

        Parameters
        ----------
        intensity               : torch.tensor
                                    intensity of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(intensity.shape) == 2:
            intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
        Speckle_C = self.functional_conv2d(intensity)
        loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
        return loss_value


    def functional_conv2d(self, intensity):
        """
        Calculates the speckle contrast of the intensity.

        Parameters
        ----------
        intensity                : torch.tensor
                                    Intensity of the complex field.

        Returns
        -------

        Speckle_C               : torch.tensor
                                    The computed speckle contrast.
        """
        mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
        var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
        Speckle_C = var / mean
        return Speckle_C
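
A short usage sketch measuring local speckle contrast over 11 by 11 windows with a stride of one:

import torch
from odak.learn.wave import speckle_contrast

regularizer = speckle_contrast(kernel_size = 11, step_size = (1, 1))
intensity = torch.rand(256, 256) + 0.5   # strictly positive placeholder intensities
loss_value = regularizer(intensity)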

__init__(kernel_size=11, step_size=(1, 1), loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel_size
                        Convolution filter kernel size, 11 by 11 average kernel by default.
    
  • step_size
                        Convolution stride in height and width direction.
    
  • loss
                        loss function, L2 Loss by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
    """
    Parameters
    ----------
    kernel_size             : torch.tensor
                                Convolution filter kernel size, 11 by 11 average kernel by default.
    step_size               : tuple
                                Convolution stride in height and width direction.
    loss                    : torch.nn.Module
                                loss function, L2 Loss by default.
    """
    super(speckle_contrast, self).__init__()
    self.device = device
    self.loss = loss
    self.step_size = step_size
    self.kernel_size = kernel_size
    self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
    self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))

forward(intensity)

Calculates the speckle contrast Loss.

Parameters:

  • intensity
                        intensity of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, intensity):
    """
    Calculates the speckle contrast Loss.

    Parameters
    ----------
    intensity               : torch.tensor
                                intensity of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(intensity.shape) == 2:
        intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
    Speckle_C = self.functional_conv2d(intensity)
    loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
    return loss_value

functional_conv2d(intensity)

Calculates the speckle contrast of the intensity.

Parameters:

  • intensity
                        Intensity of the complex field.
    

Returns:

  • Speckle_C ( tensor ) –

    The computed speckle contrast.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, intensity):
    """
    Calculates the speckle contrast of the intensity.

    Parameters
    ----------
    intensity                : torch.tensor
                                Intensity of the complex field.

    Returns
    -------

    Speckle_C               : torch.tensor
                                The computed speckle contrast.
    """
    mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
    var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
    Speckle_C = var / mean
    return Speckle_C

mixed_color_hologram_optimizer

A class for optimizing holograms.

Source code in odak/learn/wave/optimizers.py
class mixed_color_hologram_optimizer():
    """
    A class for optimizing holograms.
    """
    def __init__(self,
                 wavelengths,
                 resolution,
                 targets,
                 propagator,
                 number_of_frames = 3,
                 number_of_depth_layers = 1,
                 learning_rate = 2e-2,
                 learning_rate_floor = 5e-3,
                 double_phase = True,
                 scale_factor = 1,
                 method = 'multi-color',
                 channel_power_filename = '',
                 device = None,
                 loss_function = None,
                 peak_amplitude = 1.0,
                 optimize_peak_amplitude = False,
                 img_loss_thres = 2e-3,
                 reduction = 'sum'
                ):
        self.device = device
        if isinstance(self.device, type(None)):
            self.device = torch.device("cpu")
        torch.cuda.empty_cache()
        torch.random.seed()
        self.wavelengths = wavelengths
        self.resolution = resolution
        self.targets = targets
        self.scale_factor = scale_factor
        self.propagator = propagator
        self.learning_rate = learning_rate
        self.learning_rate_floor = learning_rate_floor
        self.number_of_channels = len(self.wavelengths)
        self.number_of_frames = number_of_frames
        self.number_of_depth_layers = number_of_depth_layers
        self.double_phase = double_phase
        self.channel_power_filename = channel_power_filename
        self.method = method
        self.upsample = torch.nn.Upsample(scale_factor = self.scale_factor, mode = 'nearest')
        if self.method != 'conventional' and self.method != 'multi-color':
           logging.warning('Unknown optimization method. Options are conventional or multi-color.')
           import sys
           sys.exit()
        self.peak_amplitude = peak_amplitude
        self.optimize_peak_amplitude = optimize_peak_amplitude
        if self.optimize_peak_amplitude:
            self.init_peak_amplitude_scale()
        self.img_loss_thres = img_loss_thres
        self.kernels = []
        self.init_phase()
        self.init_channel_power()
        self.init_loss_function(loss_function, reduction = reduction)
        self.init_amplitude()
        self.init_phase_scale()


    def init_peak_amplitude_scale(self):
        """
        Internal function to set the phase scale.
        """
        self.peak_amplitude = torch.tensor(
                                           self.peak_amplitude,
                                           requires_grad = True,
                                           device=self.device
                                          )


    def init_phase_scale(self):
        """
        Internal function to set the phase scale.
        """
        if self.method == 'conventional':
            self.phase_scale = torch.tensor(
                                            [
                                             1.,
                                             1.,
                                             1.
                                            ],
                                            requires_grad = False,
                                            device = self.device
                                           )
        if self.method == 'multi-color':
            self.phase_scale = torch.tensor(
                                            [
                                             1.,
                                             1.,
                                             1.
                                            ],
                                            requires_grad = False,
                                            device = self.device
                                           )


    def init_amplitude(self):
        """
        Internal function to set the amplitude of the illumination source.
        """
        self.amplitude = torch.ones(
                                    self.resolution[0],
                                    self.resolution[1],
                                    requires_grad = False,
                                    device = self.device
                                   )


    def init_phase(self):
        """
        Internal function to set the starting phase of the phase-only hologram.
        """
        self.phase = torch.zeros(
                                 self.number_of_frames,
                                 self.resolution[0],
                                 self.resolution[1],
                                 device = self.device,
                                 requires_grad = True
                                )
        self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)


    def init_channel_power(self):
        """
        Internal function to set the starting phase of the phase-only hologram.
        """
        if self.method == 'conventional':
            logging.warning('Scheme: Conventional')
            self.channel_power = torch.eye(
                                           self.number_of_frames,
                                           self.number_of_channels,
                                           device = self.device,
                                           requires_grad = False
                                          )

        elif self.method == 'multi-color':
            logging.warning('Scheme: Multi-color')
            self.channel_power = torch.ones(
                                            self.number_of_frames,
                                            self.number_of_channels,
                                            device = self.device,
                                            requires_grad = True
                                           )
        if self.channel_power_filename != '':
            self.channel_power = torch_load(self.channel_power_filename).to(self.device)
            self.channel_power.requires_grad = False
            self.channel_power[self.channel_power < 0.] = 0.
            self.channel_power[self.channel_power > 1.] = 1.
            if self.method == 'multi-color':
                self.channel_power.requires_grad = True
            if self.method == 'conventional':
                self.channel_power = torch.abs(torch.cos(self.channel_power))
            logging.warning('Channel powers:')
            logging.warning(self.channel_power)
            logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))
        self.propagator.set_laser_powers(self.channel_power)



    def init_optimizer(self):
        """
        Internal function to set the optimizer.
        """
        optimization_variables = [self.phase, self.offset]
        if self.optimize_peak_amplitude:
            optimization_variables.append(self.peak_amplitude)
        if self.method == 'multi-color':
            optimization_variables.append(self.propagator.channel_power)
        self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)


    def init_loss_function(self, loss_function, reduction = 'sum'):
        """
        Internal function to set the loss function.
        """
        self.l2_loss = torch.nn.MSELoss(reduction = reduction)
        self.loss_type = 'custom'
        self.loss_function = loss_function
        if isinstance(self.loss_function, type(None)):
            self.loss_type = 'conventional'
            self.loss_function = torch.nn.MSELoss(reduction = reduction)



    def evaluate(self, input_image, target_image, plane_id = 0):
        """
        Internal function to evaluate the loss.
        """
        if self.loss_type == 'conventional':
            loss = self.loss_function(input_image, target_image)
        elif self.loss_type == 'custom':
            loss = 0
            for i in range(len(self.wavelengths)):
                loss += self.loss_function(
                                           input_image[i],
                                           target_image[i],
                                           plane_id = plane_id
                                          )
        return loss



    def reconstruct(self, hologram_phases):
        """
        Internal function to reconstruct a given hologram.


        Parameters
        ----------
        hologram_phase             : torch.tensor
                                     A monochrome hologram phase [mxn].

        Returns
        -------
        reconstruction_intensities : torch.tensor
                                     Reconstructed frames.
        reconstruction_intensity   : torch.tensor
                                     Reconstructed image.
        peak_intensity             : float
                                     Peak intensity in the reconstructed image.
        """
        torch.no_grad()
        reconstruction_intensities = torch.zeros(
                                                 self.number_of_frames,
                                                 self.number_of_depth_layers,
                                                 self.number_of_channels,
                                                 self.resolution[0] * self.scale_factor,
                                                 self.resolution[1] * self.scale_factor,
                                                 device = self.device
                                                )
        for frame_id in range(self.number_of_frames):
            for depth_id in range(self.number_of_depth_layers):
                for channel_id in range(self.number_of_channels):
                    laser_power = self.propagator.get_laser_powers()[frame_id][channel_id]
                    hologram = generate_complex_field(laser_power * self.amplitude, hologram_phases[frame_id] * self.phase_scale[channel_id])
                    reconstruction_field = self.propagator(hologram, channel_id, depth_id)
                    reconstruction_intensities[frame_id, depth_id, channel_id] = calculate_amplitude(reconstruction_field) ** 2
        return reconstruction_intensities


    def double_phase_constrain(self, phase, phase_offset):
        """
        Internal function to constrain a given phase similarly to double phase encoding.

        Parameters
        ----------
        phase                      : torch.tensor
                                     Input phase values to be constrained.
        phase_offset               : torch.tensor
                                     Input phase offset value.

        Returns
        -------
        phase_only                 : torch.tensor
                                     Constrained output phase.
        """
        phase_zero_mean = phase - torch.mean(phase)
        phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * np.pi)
        phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * np.pi)
        loss = multi_scale_total_variation_loss(phase_low, levels = 6)
        loss += multi_scale_total_variation_loss(phase_high, levels = 6)
        loss += torch.std(phase_low)
        loss += torch.std(phase_high)
        phase_only = torch.zeros_like(phase)
        phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
        phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
        phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
        phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
        return phase_only, loss


    def direct_phase_constrain(self, phase, phase_offset):
        """
        Internal function to constrain a given phase.

        Parameters
        ----------
        phase                      : torch.tensor
                                     Input phase values to be constrained.
        phase_offset               : torch.tensor
                                     Input phase offset value.

        Returns
        -------
        phase_only                 : torch.tensor
                                     Constrained output phase.
        """
        phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi)
        loss = multi_scale_total_variation_loss(phase, levels = 6)
        loss = multi_scale_total_variation_loss(phase_offset, levels = 6)
        return phase_only, loss


    def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
        """
        Function to optimize multiplane phase-only holograms using stochastic gradient descent.

        Parameters
        ----------
        number_of_iterations       : float
                                     Number of iterations.
        weights                    : list
                                     Weights used in the loss function.

        Returns
        -------
        hologram                   : torch.tensor
                                     Optimised hologram.
        """
        hologram_phases = torch.zeros(
                                      self.number_of_frames,
                                      self.resolution[0],
                                      self.resolution[1],
                                      device = self.device
                                     )
        t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)
        if self.optimize_peak_amplitude:
            peak_amp_cache = self.peak_amplitude.item()
        for step in t:
            for g in self.optimizer.param_groups:
                g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations
                if g['lr'] < self.learning_rate_floor:
                    g['lr'] = self.learning_rate_floor
                learning_rate = g['lr']
            total_loss = 0
            t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)
            for depth_id in t_depth:
                self.optimizer.zero_grad()
                depth_target = self.targets[depth_id]
                reconstruction_intensities = torch.zeros(
                                                         self.number_of_frames,
                                                         self.number_of_channels,
                                                         self.resolution[0] * self.scale_factor,
                                                         self.resolution[1] * self.scale_factor,
                                                         device = self.device
                                                        )
                loss_variation_hologram = 0
                laser_powers = self.propagator.get_laser_powers()
                for frame_id in range(self.number_of_frames):
                    if self.double_phase:
                        phase, loss_phase = self.double_phase_constrain(
                                                                        self.phase[frame_id],
                                                                        self.offset[frame_id]
                                                                       )
                    else:
                        phase, loss_phase = self.direct_phase_constrain(
                                                                        self.phase[frame_id],
                                                                        self.offset[frame_id]
                                                                       )
                    loss_variation_hologram += loss_phase
                    for channel_id in range(self.number_of_channels):
                        if self.scale_factor != 1:
                            phase_scaled = torch.zeros_like(self.amplitude)
                            phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
                            self.amplitude[1::self.scale_factor, 1::self.scale_factor] = 0.
                        else:
                            phase_scaled = phase
                        scaled_phase = phase_scaled * self.phase_scale[channel_id]
                        laser_power = laser_powers[frame_id][channel_id]
                        amplitude = laser_power * self.amplitude
                        hologram = generate_complex_field(amplitude, scaled_phase)
                        reconstruction_field = self.propagator(hologram, channel_id, depth_id)
                        intensity = calculate_amplitude(reconstruction_field) ** 2
                        reconstruction_intensities[frame_id, channel_id] += intensity
                    hologram_phases[frame_id] = phase.detach().clone()
                loss_laser = self.l2_loss(
                                          torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,
                                          torch.sum(laser_powers, dim = 0)
                                         )
                loss_laser += self.l2_loss(
                                           torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),
                                           torch.sum(laser_powers).view(1,)
                                          )
                loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))
                reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)
                loss_image = self.evaluate(
                                           reconstruction_intensity,
                                           depth_target * self.peak_amplitude,
                                           plane_id = depth_id
                                          )
                loss = weights[0] * loss_image
                loss += weights[1] * loss_laser
                loss += weights[2] * loss_variation_hologram
                include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres
                if include_pa_loss_flag:
                    loss -= self.peak_amplitude * 1.
                if self.method == 'conventional':
                    loss.backward()
                else:
                    loss.backward(retain_graph = True)
                self.optimizer.step()
                if include_pa_loss_flag:
                    peak_amp_cache = self.peak_amplitude.item()
                else:
                    with torch.no_grad():
                        if self.optimize_peak_amplitude:
                            self.peak_amplitude.view([1])[0] = peak_amp_cache
                total_loss += loss.detach().item()
                loss_image = loss_image.detach()
                del loss_laser
                del loss_variation_hologram
                del loss
            description = "Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)
            t.set_description(description)
            del total_loss
            del loss_image
            del scaled_phase
            del reconstruction_field
            del reconstruction_intensities
            del intensity
            del phase
            del amplitude
            del hologram
        logging.warning(description)
        return hologram_phases.detach()


    def optimize(self, number_of_iterations=100, weights=[1., 1., 1.]):
        """
        Function to optimize multiplane phase-only holograms.

        Parameters
        ----------
        number_of_iterations       : int
                                     Number of iterations.
        weights                    : list
                                     Loss weights.

        Returns
        -------
        hologram_phases            : torch.tensor
                                     Phases of the optimized phase-only hologram.
        reconstruction_intensities : torch.tensor
                                     Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
        laser_powers               : torch.tensor
                                     Laser powers used for each frame and color channel.
        channel_powers             : torch.tensor
                                     Channel power values of the propagator.
        peak_amplitude             : float
                                     Final peak amplitude.
        """
        self.init_optimizer()
        hologram_phases = self.gradient_descent(
                                                number_of_iterations=number_of_iterations,
                                                weights=weights
                                               )
        with torch.no_grad():
            reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
        laser_powers = self.propagator.get_laser_powers()
        channel_powers = self.propagator.channel_power
        logging.warning("Final peak amplitude: {}".format(self.peak_amplitude))
        logging.warning('Laser powers: {}'.format(laser_powers))
        return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)

direct_phase_constrain(phase, phase_offset)

Internal function to constrain a given phase.

Parameters:

  • phase
                         Input phase values to be constrained.
    
  • phase_offset
                         Input phase offset value.
    

Returns:

  • phase_only ( tensor ) –

    Constrained output phase.

Source code in odak/learn/wave/optimizers.py
def direct_phase_constrain(self, phase, phase_offset):
    """
    Internal function to constrain a given phase.

    Parameters
    ----------
    phase                      : torch.tensor
                                 Input phase values to be constrained.
    phase_offset               : torch.tensor
                                 Input phase offset value.

    Returns
    -------
    phase_only                 : torch.tensor
                                 Constrained output phase.
    """
    phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi)
    loss = multi_scale_total_variation_loss(phase, levels = 6)
    loss += multi_scale_total_variation_loss(phase_offset, levels = 6)
    return phase_only, loss

double_phase_constrain(phase, phase_offset)

Internal function to constrain a given phase similarly to double phase encoding.

Parameters:

  • phase
                         Input phase values to be constrained.
    
  • phase_offset
                         Input phase offset value.
    

Returns:

  • phase_only ( tensor ) –

    Constrained output phase.

Source code in odak/learn/wave/optimizers.py
def double_phase_constrain(self, phase, phase_offset):
    """
    Internal function to constrain a given phase similarly to double phase encoding.

    Parameters
    ----------
    phase                      : torch.tensor
                                 Input phase values to be constrained.
    phase_offset               : torch.tensor
                                 Input phase offset value.

    Returns
    -------
    phase_only                 : torch.tensor
                                 Constrained output phase.
    """
    phase_zero_mean = phase - torch.mean(phase)
    phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * np.pi)
    phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * np.pi)
    loss = multi_scale_total_variation_loss(phase_low, levels = 6)
    loss += multi_scale_total_variation_loss(phase_high, levels = 6)
    loss += torch.std(phase_low)
    loss += torch.std(phase_high)
    phase_only = torch.zeros_like(phase)
    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
    return phase_only, loss
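
For intuition, the constraint interleaves the low and high phase values on a pixel-level checkerboard, which is the core idea behind double phase encoding. The standalone sketch below (not part of odak; names and the 4 x 4 size are illustrative) reproduces only that interleaving step:

import math
import torch

# Illustrative checkerboard interleave in the double phase encoding style.
phase = torch.rand(4, 4) * 2 * math.pi
offset = torch.rand(4, 4)
phase_zero_mean = phase - phase.mean()
phase_low = phase_zero_mean - offset
phase_high = phase_zero_mean + offset
encoded = torch.zeros_like(phase)
encoded[0::2, 0::2] = phase_low[0::2, 0::2]
encoded[0::2, 1::2] = phase_high[0::2, 1::2]
encoded[1::2, 0::2] = phase_high[1::2, 0::2]
encoded[1::2, 1::2] = phase_low[1::2, 1::2]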

evaluate(input_image, target_image, plane_id=0)

Internal function to evaluate the loss.

Source code in odak/learn/wave/optimizers.py
def evaluate(self, input_image, target_image, plane_id = 0):
    """
    Internal function to evaluate the loss.
    """
    if self.loss_type == 'conventional':
        loss = self.loss_function(input_image, target_image)
    elif self.loss_type == 'custom':
        loss = 0
        for i in range(len(self.wavelengths)):
            loss += self.loss_function(
                                       input_image[i],
                                       target_image[i],
                                       plane_id = plane_id
                                      )
    return loss
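
When a custom loss function is supplied through init_loss_function, evaluate calls it once per wavelength with a plane_id keyword, as shown above. A hedged sketch of a compatible callable follows; the name and the body are hypothetical, only the call signature mirrors the code above:

import torch

# Hypothetical custom loss; evaluate passes one color channel at a time plus a plane_id keyword.
def my_plane_weighted_loss(input_image, target_image, plane_id = 0):
    return torch.mean((input_image - target_image) ** 2)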

gradient_descent(number_of_iterations=100, weights=[1.0, 1.0, 0.0, 0.0])

Function to optimize multiplane phase-only holograms using stochastic gradient descent.

Parameters:

  • number_of_iterations
                         Number of iterations.
    
  • weights
                         Weights used in the loss function.
    

Returns:

  • hologram_phases ( tensor ) –

    Phases of the optimised phase-only holograms, one per frame.

Source code in odak/learn/wave/optimizers.py
def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
    """
    Function to optimize multiplane phase-only holograms using stochastic gradient descent.

    Parameters
    ----------
    number_of_iterations       : int
                                 Number of iterations.
    weights                    : list
                                 Weights used in the loss function.

    Returns
    -------
    hologram_phases            : torch.tensor
                                 Phases of the optimised phase-only holograms, one per frame.
    """
    hologram_phases = torch.zeros(
                                  self.number_of_frames,
                                  self.resolution[0],
                                  self.resolution[1],
                                  device = self.device
                                 )
    t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)
    if self.optimize_peak_amplitude:
        peak_amp_cache = self.peak_amplitude.item()
    for step in t:
        for g in self.optimizer.param_groups:
            g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations
            if g['lr'] < self.learning_rate_floor:
                g['lr'] = self.learning_rate_floor
            learning_rate = g['lr']
        total_loss = 0
        t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)
        for depth_id in t_depth:
            self.optimizer.zero_grad()
            depth_target = self.targets[depth_id]
            reconstruction_intensities = torch.zeros(
                                                     self.number_of_frames,
                                                     self.number_of_channels,
                                                     self.resolution[0] * self.scale_factor,
                                                     self.resolution[1] * self.scale_factor,
                                                     device = self.device
                                                    )
            loss_variation_hologram = 0
            laser_powers = self.propagator.get_laser_powers()
            for frame_id in range(self.number_of_frames):
                if self.double_phase:
                    phase, loss_phase = self.double_phase_constrain(
                                                                    self.phase[frame_id],
                                                                    self.offset[frame_id]
                                                                   )
                else:
                    phase, loss_phase = self.direct_phase_constrain(
                                                                    self.phase[frame_id],
                                                                    self.offset[frame_id]
                                                                   )
                loss_variation_hologram += loss_phase
                for channel_id in range(self.number_of_channels):
                    if self.scale_factor != 1:
                        phase_scaled = torch.zeros_like(self.amplitude)
                        phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
                        self.amplitude[1::self.scale_factor, 1::self.scale_factor] = 0.
                    else:
                        phase_scaled = phase
                    scaled_phase = phase_scaled * self.phase_scale[channel_id]
                    laser_power = laser_powers[frame_id][channel_id]
                    amplitude = laser_power * self.amplitude
                    hologram = generate_complex_field(amplitude, scaled_phase)
                    reconstruction_field = self.propagator(hologram, channel_id, depth_id)
                    intensity = calculate_amplitude(reconstruction_field) ** 2
                    reconstruction_intensities[frame_id, channel_id] += intensity
                hologram_phases[frame_id] = phase.detach().clone()
            loss_laser = self.l2_loss(
                                      torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,
                                      torch.sum(laser_powers, dim = 0)
                                     )
            loss_laser += self.l2_loss(
                                       torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),
                                       torch.sum(laser_powers).view(1,)
                                      )
            loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))
            reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)
            loss_image = self.evaluate(
                                       reconstruction_intensity,
                                       depth_target * self.peak_amplitude,
                                       plane_id = depth_id
                                      )
            loss = weights[0] * loss_image
            loss += weights[1] * loss_laser
            loss += weights[2] * loss_variation_hologram
            include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres
            if include_pa_loss_flag:
                loss -= self.peak_amplitude * 1.
            if self.method == 'conventional':
                loss.backward()
            else:
                loss.backward(retain_graph = True)
            self.optimizer.step()
            if include_pa_loss_flag:
                peak_amp_cache = self.peak_amplitude.item()
            else:
                with torch.no_grad():
                    if self.optimize_peak_amplitude:
                        self.peak_amplitude.view([1])[0] = peak_amp_cache
            total_loss += loss.detach().item()
            loss_image = loss_image.detach()
            del loss_laser
            del loss_variation_hologram
            del loss
        description = "Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)
        t.set_description(description)
        del total_loss
        del loss_image
        del scaled_phase
        del reconstruction_field
        del reconstruction_intensities
        del intensity
        del phase
        del amplitude
        del hologram
    logging.warning(description)
    return hologram_phases.detach()
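
At the start of every iteration, the learning rate is decreased linearly from learning_rate towards learning_rate_floor and clipped at the floor. A minimal standalone sketch of that schedule (the numeric values are placeholders, not recommendations):

# Linear learning rate decay with a floor, mirroring the loop above.
learning_rate = 2e-2          # hypothetical starting value
learning_rate_floor = 5e-3    # hypothetical floor
number_of_iterations = 100
lr = learning_rate
for step in range(number_of_iterations):
    lr -= (learning_rate - learning_rate_floor) / number_of_iterations
    lr = max(lr, learning_rate_floor)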

init_amplitude()

Internal function to set the amplitude of the illumination source.

Source code in odak/learn/wave/optimizers.py
def init_amplitude(self):
    """
    Internal function to set the amplitude of the illumination source.
    """
    self.amplitude = torch.ones(
                                self.resolution[0],
                                self.resolution[1],
                                requires_grad = False,
                                device = self.device
                               )

init_channel_power()

Internal function to set the starting channel powers of the light source.

Source code in odak/learn/wave/optimizers.py
def init_channel_power(self):
    """
    Internal function to set the starting channel powers of the light source.
    """
    if self.method == 'conventional':
        logging.warning('Scheme: Conventional')
        self.channel_power = torch.eye(
                                       self.number_of_frames,
                                       self.number_of_channels,
                                       device = self.device,
                                       requires_grad = False
                                      )

    elif self.method == 'multi-color':
        logging.warning('Scheme: Multi-color')
        self.channel_power = torch.ones(
                                        self.number_of_frames,
                                        self.number_of_channels,
                                        device = self.device,
                                        requires_grad = True
                                       )
    if self.channel_power_filename != '':
        self.channel_power = torch_load(self.channel_power_filename).to(self.device)
        self.channel_power.requires_grad = False
        self.channel_power[self.channel_power < 0.] = 0.
        self.channel_power[self.channel_power > 1.] = 1.
        if self.method == 'multi-color':
            self.channel_power.requires_grad = True
        if self.method == 'conventional':
            self.channel_power = torch.abs(torch.cos(self.channel_power))
        logging.warning('Channel powers:')
        logging.warning(self.channel_power)
        logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))
    self.propagator.set_laser_powers(self.channel_power)
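
The two schemes initialise the frame-by-channel power matrix differently: the conventional scheme starts from an identity matrix so that each frame drives a single color channel, whereas the multi-color scheme starts with every channel switched on in every frame. A small sketch of the two starting matrices for three frames and three channels (devices and gradient flags omitted):

import torch

# Conventional scheme: frame i illuminates color channel i only at the start.
conventional_start = torch.eye(3, 3)
# Multi-color scheme: every frame starts with all color channels on.
multi_color_start = torch.ones(3, 3)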

init_loss_function(loss_function, reduction='sum')

Internal function to set the loss function.

Source code in odak/learn/wave/optimizers.py
def init_loss_function(self, loss_function, reduction = 'sum'):
    """
    Internal function to set the loss function.
    """
    self.l2_loss = torch.nn.MSELoss(reduction = reduction)
    self.loss_type = 'custom'
    self.loss_function = loss_function
    if isinstance(self.loss_function, type(None)):
        self.loss_type = 'conventional'
        self.loss_function = torch.nn.MSELoss(reduction = reduction)

init_optimizer()

Internal function to set the optimizer.

Source code in odak/learn/wave/optimizers.py
def init_optimizer(self):
    """
    Internal function to set the optimizer.
    """
    optimization_variables = [self.phase, self.offset]
    if self.optimize_peak_amplitude:
        optimization_variables.append(self.peak_amplitude)
    if self.method == 'multi-color':
        optimization_variables.append(self.propagator.channel_power)
    self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)

init_peak_amplitude_scale()

Internal function to set the peak amplitude.

Source code in odak/learn/wave/optimizers.py
def init_peak_amplitude_scale(self):
    """
    Internal function to set the peak amplitude.
    """
    self.peak_amplitude = torch.tensor(
                                       self.peak_amplitude,
                                       requires_grad = True,
                                       device=self.device
                                      )

init_phase()

Internal function to set the starting phase of the phase-only hologram.

Source code in odak/learn/wave/optimizers.py
def init_phase(self):
    """
    Internal function to set the starting phase of the phase-only hologram.
    """
    self.phase = torch.zeros(
                             self.number_of_frames,
                             self.resolution[0],
                             self.resolution[1],
                             device = self.device,
                             requires_grad = True
                            )
    self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)

init_phase_scale()

Internal function to set the phase scale.

Source code in odak/learn/wave/optimizers.py
def init_phase_scale(self):
    """
    Internal function to set the phase scale.
    """
    if self.method == 'conventional':
        self.phase_scale = torch.tensor(
                                        [
                                         1.,
                                         1.,
                                         1.
                                        ],
                                        requires_grad = False,
                                        device = self.device
                                       )
    if self.method == 'multi-color':
        self.phase_scale = torch.tensor(
                                        [
                                         1.,
                                         1.,
                                         1.
                                        ],
                                        requires_grad = False,
                                        device = self.device
                                       )

optimize(number_of_iterations=100, weights=[1.0, 1.0, 1.0])

Function to optimize multiplane phase-only holograms.

Parameters:

  • number_of_iterations
                         Number of iterations.
    
  • weights
                         Loss weights.
    

Returns:

  • hologram_phases ( tensor ) –

    Phases of the optimized phase-only hologram.

  • reconstruction_intensities ( tensor ) –

    Intensities of the images reconstructed at each plane with the optimized phase-only hologram.

  • laser_powers ( tensor ) –

    Laser powers used for each frame and color channel.

  • channel_powers ( tensor ) –

    Channel power values of the propagator.

  • peak_amplitude ( float ) –

    Final peak amplitude.

Source code in odak/learn/wave/optimizers.py
def optimize(self, number_of_iterations=100, weights=[1., 1., 1.]):
    """
    Function to optimize multiplane phase-only holograms.

    Parameters
    ----------
    number_of_iterations       : int
                                 Number of iterations.
    weights                    : list
                                 Loss weights.

    Returns
    -------
    hologram_phases            : torch.tensor
                                 Phases of the optimized phase-only hologram.
    reconstruction_intensities : torch.tensor
                                 Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
    laser_powers               : torch.tensor
                                 Laser powers used for each frame and color channel.
    channel_powers             : torch.tensor
                                 Channel power values of the propagator.
    peak_amplitude             : float
                                 Final peak amplitude.
    """
    self.init_optimizer()
    hologram_phases = self.gradient_descent(
                                            number_of_iterations=number_of_iterations,
                                            weights=weights
                                           )
    with torch.no_grad():
        reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
    laser_powers = self.propagator.get_laser_powers()
    channel_powers = self.propagator.channel_power
    logging.warning("Final peak amplitude: {}".format(self.peak_amplitude))
    logging.warning('Laser powers: {}'.format(laser_powers))
    return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)

reconstruct(hologram_phases)

Internal function to reconstruct a given hologram.

Parameters:

  • hologram_phases
                         Hologram phases for each frame [number of frames x m x n].
    

Returns:

  • reconstruction_intensities ( tensor ) –

    Reconstructed intensities for each frame, depth layer and color channel.

Source code in odak/learn/wave/optimizers.py
def reconstruct(self, hologram_phases):
    """
    Internal function to reconstruct a given hologram.


    Parameters
    ----------
    hologram_phases            : torch.tensor
                                 Hologram phases for each frame [number of frames x m x n].

    Returns
    -------
    reconstruction_intensities : torch.tensor
                                 Reconstructed intensities for each frame, depth layer and color channel.
    """
    torch.no_grad()
    reconstruction_intensities = torch.zeros(
                                             self.number_of_frames,
                                             self.number_of_depth_layers,
                                             self.number_of_channels,
                                             self.resolution[0] * self.scale_factor,
                                             self.resolution[1] * self.scale_factor,
                                             device = self.device
                                            )
    for frame_id in range(self.number_of_frames):
        for depth_id in range(self.number_of_depth_layers):
            for channel_id in range(self.number_of_channels):
                laser_power = self.propagator.get_laser_powers()[frame_id][channel_id]
                hologram = generate_complex_field(laser_power * self.amplitude, hologram_phases[frame_id] * self.phase_scale[channel_id])
                reconstruction_field = self.propagator(hologram, channel_id, depth_id)
                reconstruction_intensities[frame_id, depth_id, channel_id] = calculate_amplitude(reconstruction_field) ** 2
    return reconstruction_intensities
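
The returned tensor is indexed as [frame, depth layer, channel, height, width]. Summing over the frame axis gives the time-averaged image observed at each plane, which is also how gradient_descent combines frames before evaluating the image loss. A standalone shape sketch (all sizes are illustrative):

import torch

# Illustrative shapes: 3 frames, 2 depth layers, 3 color channels, 8 x 8 pixels.
reconstruction_intensities = torch.rand(3, 2, 3, 8, 8)
per_plane_images = reconstruction_intensities.sum(dim = 0)  # [depth, channel, height, width]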

multiplane_hologram_optimizer

A highly configurable class for optimizing multiplane holograms.
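
Below is a minimal, hedged construction sketch based only on the constructor signature shown in the source code further down; the import path assumes the module location stated there, and every numeric value is a placeholder chosen for illustration rather than a recommendation:

import torch
from odak.learn.wave.optimizers import multiplane_hologram_optimizer

# Hypothetical settings: a 1080 x 1920 SLM, four target planes, green illumination.
targets = torch.rand(4, 1080, 1920)
optimizer = multiplane_hologram_optimizer(
                                          wavelength = 515e-9,
                                          image_location = 0.0,
                                          image_spacing = 1e-3,
                                          slm_pixel_pitch = 8e-6,
                                          slm_resolution = [1080, 1920],
                                          targets = targets,
                                          number_of_planes = 4,
                                          device = torch.device('cpu')
                                         )
hologram_phase, hologram_amplitude, reconstruction_intensities = optimizer.optimize()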

Source code in odak/learn/wave/optimizers.py
class multiplane_hologram_optimizer():
    """
    A highly configurable class for optimizing multiplane holograms.
    """


    def __init__(self, wavelength, image_location, 
                 image_spacing, slm_pixel_pitch,
                 slm_resolution, targets,
                 propagation_type = 'Bandlimited Angular Spectrum', 
                 propagator_type = 'back and forth',
                 number_of_iterations = 10, learning_rate = 0.1,
                 phase_initial = None, amplitude_initial = None,
                 loss_function = None,
                 mask_limits = [0.2, 0.8, 0.05, 0.95],
                 number_of_planes = 4,
                 zero_mode_distance = 0.15,
                 device = torch.device('cpu')
                ):
        self.device = device
        torch.cuda.empty_cache()
        torch.random.seed()
        self.wavelength = wavelength
        self.image_location = image_location
        self.image_spacing = image_spacing
        self.slm_resolution = slm_resolution
        self.targets = targets
        self.slm_pixel_pitch = slm_pixel_pitch
        self.number_of_planes = number_of_planes
        self.zero_mode_distance = zero_mode_distance
        self.model = propagator(
                                resolution = self.slm_resolution,
                                wavelengths = [self.wavelength,],
                                pixel_pitch = self.slm_pixel_pitch,
                                number_of_frames = 1,
                                number_of_depth_layers = self.number_of_planes,
                                volume_depth = self.number_of_planes * self.image_spacing,
                                image_location_offset = self.image_location,
                                propagation_type = propagation_type,
                                propagator_type = propagator_type,
                                back_and_forth_distance = self.zero_mode_distance,
                                device = self.device
                               )
        self.propagation_type = propagation_type
        self.mask_limits = mask_limits
        self.number_of_iterations = number_of_iterations
        self.learning_rate = learning_rate 
        self.scene_center = self.image_spacing * (self.number_of_planes - 1) / 2.
        self.wavenumber = wavenumber(self.wavelength)
        self.init_phase(phase_initial)
        self.init_amplitude(amplitude_initial)
        self.init_optimizer()
        self.init_mask()
        self.init_loss_function(loss_function)


    def init_amplitude(self, amplitude_initial):
        """
        Internal function to set the amplitude of the illumination source.
        """
        self.amplitude = amplitude_initial
        if isinstance(self.amplitude, type(None)):
            self.amplitude = torch.ones(
                                        self.slm_resolution[0],
                                        self.slm_resolution[1],
                                        requires_grad = False
                                       ).to(self.device)


    def init_phase(self, phase_initial):
        """
        Internal function to set the starting phase of the phase-only hologram.
        """
        self.phase = phase_initial
        if isinstance(self.phase, type(None)):
            self.phase = torch.rand(
                                    self.slm_resolution[0],
                                    self.slm_resolution[1]
                                   ).detach().to(self.device).requires_grad_()
        self.offset = torch.rand_like(self.phase)


    def init_optimizer(self):
        """
        Internal function to set the optimizer.
        """
        parameters = [self.phase, self.offset]
        self.optimizer = torch.optim.AdamW(parameters, lr = self.learning_rate)


    def init_loss_function(self, loss_function=None, reduction='mean'):
        """
        Internal function to set the loss function.
        """
        self.loss_function = loss_function
        self.loss_type = 'other'
        if isinstance(self.loss_function, type(None)):
            self.loss_function = torch.nn.MSELoss(reduction = reduction)
            self.loss_type = 'naive'


    def init_mask(self):
        """
        Internal function to initialise the mask used in calculating the loss.
        """
        self.mask = torch.zeros(
                                self.slm_resolution[0],
                                self.slm_resolution[1],
                                requires_grad = False,
                                device = self.device
                               )
        self.mask[
                  int(self.slm_resolution[0] * self.mask_limits[0]):int(self.slm_resolution[0] * self.mask_limits[1]),
                  int(self.slm_resolution[1] * self.mask_limits[2]):int(self.slm_resolution[1] * self.mask_limits[3])
                 ] = 1


    def evaluate(self, input_image, target_image, plane_id):
        """
        Internal function to evaluate the loss.
        """
        if self.loss_type == 'naive':
            return self.loss_function(input_image, target_image)
        else:
            return self.loss_function(input_image.unsqueeze(0), target_image, plane_id)


    def optimize(self):
        """
        Function to optimize multiplane phase-only holograms.

        Returns
        -------
        hologram_phase             : torch.tensor
                                     Phase of the optimized hologram.
        hologram_amplitude         : torch.tensor
                                     Amplitude of the optimized hologram. 
        reconstruction_intensities : torch.tensor
                                     Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
        """
        hologram = self.gradient_descent()
        hologram_phase = calculate_phase(hologram)
        hologram_amplitude = calculate_amplitude(hologram)
        reconstruction_intensities = self.reconstruct(hologram_amplitude, hologram_phase)
        return hologram_phase.detach().clone(), hologram_amplitude.detach().clone(), reconstruction_intensities.detach().clone()


    def reconstruct(self, hologram_amplitude, hologram_phase):
        """
        Internal function to reconstruct a given hologram.

        Parameters
        ----------
        hologram_amplitude         : torch.tensor
                                     A monochrome hologram amplitude [m x n].
        hologram_phase             : torch.tensor
                                     A monochrome hologram phase [m x n].

        Returns
        -------
        reconstruction_intensities : torch.tensor
                                     Reconstructed images.
        """
        hologram = generate_complex_field(hologram_amplitude, hologram_phase)
        with torch.no_grad():
            reconstruction_intensities = torch.zeros(
                                                     self.number_of_planes,
                                                     self.phase.shape[0],
                                                     self.phase.shape[1],
                                                     requires_grad = False
                                                    ).to(self.device)
            for plane_id in range(self.number_of_planes):
                reconstruction = self.model(hologram, channel_id = 0, depth_id = plane_id)
                reconstruction_intensities[plane_id] = calculate_amplitude(reconstruction) ** 2
        return reconstruction_intensities


    def double_phase_constrain(self, shifted_phase, phase_offset):
        """
        Function for generating phase-only holograms with a double phase encoding-like constraint.

        Parameters
        ----------
        shifted_phase              : torch.tensor
                                     Input phase [m x n].
        phase_offset               : torch.tensor
                                     Input offset [m x n].

        Returns
        -------
        phase                      : torch.tensor
                                     Coded phase [m x n].
        """
        phase_zero_mean = shifted_phase - torch.mean(shifted_phase)
        phase_low = phase_zero_mean - phase_offset
        phase_high = phase_zero_mean + phase_offset
        phase = torch.zeros_like(shifted_phase)
        phase[0::2, 0::2] = phase_low[0::2, 0::2]
        phase[0::2, 1::2] = phase_high[0::2, 1::2]
        phase[1::2, 0::2] = phase_high[1::2, 0::2]
        phase[1::2, 1::2] = phase_low[1::2, 1::2]
        return phase


    def gradient_descent(self):
        """
        Function to optimize multiplane phase-only holograms using gradient descent.

        Returns
        -------
        hologram                   : torch.tensor
                                     Optimised hologram.
        """
        t = tqdm(range(self.number_of_iterations), leave = False, dynamic_ncols = True)
        for step in t:
            for plane_id in range(self.number_of_planes):
                self.optimizer.zero_grad()
                phase = self.double_phase_constrain(self.phase, self.offset)
                amplitude = self.amplitude
                hologram = generate_complex_field(amplitude, phase)
                reconstruction = self.model(hologram, channel_id = 0, depth_id = plane_id)
                reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
                loss = self.evaluate(
                                     reconstruction_intensity * self.mask,
                                     self.targets[plane_id] * self.mask,
                                     plane_id
                                    )
                loss.backward(retain_graph=True)
                self.optimizer.step()
            description = "Gradient Descent, loss:{:.4f}&