
odak.learn.wave

angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution with the Angular Spectrum method for beam propagation.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as the input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
    """
    A definition to calculate convolution with the Angular Spectrum method for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as the input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_angular_spectrum_kernel(
                                    field.shape[-2], 
                                    field.shape[-1], 
                                    dx = dx, 
                                    wavelength = wavelength, 
                                    distance = distance, 
                                    device = field.device
                                   )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
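The Angular Spectrum step above multiplies the centered spectrum of the field by the transfer function H = exp(i 2π z sqrt(1/λ² - fx² - fy²)). The following is a minimal NumPy sketch of that idea, not the odak implementation. The names `angular_spectrum_kernel_np` and `propagate_np` are hypothetical, and this sketch clamps evanescent frequencies to zero. For propagating waves |H| = 1, so propagating forward by z and then by -z should recover the input field and conserve energy.

```python
import numpy as np

def angular_spectrum_kernel_np(nu, nv, dx, wavelength, distance):
    # Frequency grids spanning [-1/(2 dx), 1/(2 dx)], kernel shaped [nu, nv]
    fx = np.linspace(-0.5 / dx, 0.5 / dx, nu)
    fy = np.linspace(-0.5 / dx, 0.5 / dx, nv)
    FX, FY = np.meshgrid(fx, fy, indexing="ij")
    # Clamp negative arguments (evanescent waves) to zero in this sketch
    arg = np.clip(1.0 - (wavelength * FX) ** 2 - (wavelength * FY) ** 2, 0.0, None)
    return np.exp(1j * 2 * np.pi / wavelength * distance * np.sqrt(arg))

def propagate_np(field, H):
    # Centered FFT, multiply by the kernel, centered inverse FFT
    U1 = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(field)))
    return np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(H * U1)))

rng = np.random.default_rng(0)
field = rng.random((64, 48)) * np.exp(2j * np.pi * rng.random((64, 48)))
dx, wavelength, distance = 8e-6, 515e-9, 1e-3
H_forward = angular_spectrum_kernel_np(*field.shape, dx, wavelength, distance)
H_backward = angular_spectrum_kernel_np(*field.shape, dx, wavelength, -distance)
propagated = propagate_np(field, H_forward)
recovered = propagate_np(propagated, H_backward)
```

Using `indexing="ij"` keeps the kernel shaped like the field ([nu, nv] = field.shape). This matches how the torch helper is called with field.shape[-2] and field.shape[-1].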

band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate band-limited angular spectrum based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics Express 17.22 (2009): 19662-19673.

Parameters:

  • field
               A complex field.
               The expected size is [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as the input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
    """
    A definition to calculate band-limited angular spectrum based beam propagation. For more, see
    `Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics Express 17.22 (2009): 19662-19673`.

    Parameters
    ----------
    field            : torch.complex
                       A complex field.
                       The expected size is [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as the input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    H = get_band_limited_angular_spectrum_kernel(
                                                 field.shape[-2], 
                                                 field.shape[-1], 
                                                 dx = dx, 
                                                 wavelength = wavelength, 
                                                 distance = distance, 
                                                 device = field.device
                                                )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
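The band limit from Matsushima and Shimobaba removes spatial frequencies whose transfer-function phase would be undersampled. The cutoff is f_limit = 1 / (λ sqrt((2 Δf z)² + 1)), where Δf is the frequency step. The NumPy sketch below reproduces the `fx_max`/`fy_max` filter from get_band_limited_angular_spectrum_kernel, taking Δf = 1 / (n dx) as that helper does. The function names are hypothetical. The sketch shows that the pass band covers the whole grid at short distances and shrinks as the distance grows.

```python
import numpy as np

def band_limit(n, dx, wavelength, distance):
    # Same form as fx_max in get_band_limited_angular_spectrum_kernel:
    # the window width is n * dx, so the frequency step is 1 / (n * dx)
    width = n * dx
    return 1.0 / np.sqrt((2.0 * distance / width) ** 2 + 1.0) / wavelength

def band_limited_mask(nu, nv, dx, wavelength, distance):
    x, y = nu * dx, nv * dx
    # Frequency grids with the same endpoints as the odak helper
    fx = np.linspace(-1 / (2 * dx) + 0.5 / (2 * x), 1 / (2 * dx) - 0.5 / (2 * x), nu)
    fy = np.linspace(-1 / (2 * dx) + 0.5 / (2 * y), 1 / (2 * dx) - 0.5 / (2 * y), nv)
    FX, FY = np.meshgrid(fx, fy, indexing="ij")
    return (np.abs(FX) < band_limit(nu, dx, wavelength, distance)) & \
           (np.abs(FY) < band_limit(nv, dx, wavelength, distance))

near = band_limited_mask(256, 256, 8e-6, 515e-9, 1e-3)  # 1 mm: nothing is filtered
far = band_limited_mask(256, 256, 8e-6, 515e-9, 1e-1)   # 10 cm: high frequencies are removed
```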

custom(field, kernel, zero_padding=False, aperture=1.0)

A definition to calculate beam propagation by convolution with a custom kernel, applied as a multiplication in the Fourier domain.

Parameters:

  • field
               Complex field [m x n].
    
  • kernel
               Custom complex kernel for beam propagation.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as the input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def custom(field, kernel, zero_padding = False, aperture = 1.):
    """
    A definition to calculate beam propagation by convolution with a custom kernel, applied as a multiplication in the Fourier domain.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    kernel           : torch.complex
                       Custom complex kernel for beam propagation.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as the input field [m x n].

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    if kernel is None:
        H = torch.zeros(field.shape).to(field.device)
    else:
        H = kernel * aperture
    U1 = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(field))) * aperture
    if zero_padding:
        U2 = zero_pad(H * U1)
    else:
        U2 = H * U1
    result = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(U2)))
    return result
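custom applies the aperture to both the kernel and the centered spectrum, then transforms back. A short NumPy mirror without zero padding shows two consequences. First, an all-ones kernel with no aperture returns the field unchanged. Second, a circular pinhole around the centered DC term acts as a low-pass filter that cannot add energy. The name `custom_np` and the pinhole radius are illustrative choices, not part of odak.

```python
import numpy as np

def custom_np(field, kernel, aperture=1.0):
    # NumPy mirror of custom() without zero padding; as in the torch code,
    # the aperture multiplies both the kernel and the centered spectrum
    H = kernel * aperture
    U1 = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(field))) * aperture
    return np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(H * U1)))

rng = np.random.default_rng(1)
field = rng.random((64, 64)) + 1j * rng.random((64, 64))

# An all-ones kernel with no aperture leaves the field unchanged
identity = custom_np(field, np.ones_like(field))

# A circular pinhole of radius 8 samples around the centered DC term (index 32)
rows, cols = np.meshgrid(np.arange(64) - 32, np.arange(64) - 32, indexing="ij")
pinhole = (rows ** 2 + cols ** 2 <= 8 ** 2).astype(float)
filtered = custom_np(field, np.ones_like(field), aperture=pinhole)
```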

fraunhofer(field, k, distance, dx, wavelength)

A definition to calculate light transport using the Fraunhofer approximation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def fraunhofer(field, k, distance, dx, wavelength):
    """
    A definition to calculate light transport using the Fraunhofer approximation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).
    """
    nv, nu = field.shape[-1], field.shape[-2]
    x = torch.linspace(-nv*dx/2, nv*dx/2, nv, dtype=torch.float32)
    y = torch.linspace(-nu*dx/2, nu*dx/2, nu, dtype=torch.float32)
    Y, X = torch.meshgrid(y, x, indexing='ij')
    Z = torch.pow(X, 2) + torch.pow(Y, 2)
    c = 1./(1j*wavelength*distance)*torch.exp(1j*k*0.5/distance*Z)
    c = c.to(field.device)
    result = c * \
             torch.fft.ifftshift(torch.fft.fft2(
             torch.fft.fftshift(field)))*pow(dx, 2)
    return result
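A common rule of thumb is that the Fraunhofer approximation holds when the Fresnel number N_F = a² / (λ z) is much smaller than one, where a is the aperture half-width. In Voelz's FFT-based Fraunhofer step, the observation-plane sample spacing is λ z / (N dx). The sketch below computes both quantities for a field the size of the sampled grid. Taking a = N dx / 2 is an assumption: real apertures are often smaller than the window. The function names are hypothetical.

```python
import math

def fresnel_number(n, dx, wavelength, distance):
    # Assumes the aperture half-width equals half the sampled window, a = n * dx / 2
    a = n * dx / 2.0
    return a ** 2 / (wavelength * distance)

def far_field_pitch(n, dx, wavelength, distance):
    # Observation-plane sample spacing of an FFT-based Fraunhofer step: lambda * z / (n * dx)
    return wavelength * distance / (n * dx)

# A 256 x 256 field with 8 um pixels at 515 nm, evaluated at three distances (meters)
summary = {z: (fresnel_number(256, 8e-6, 515e-9, z), far_field_pitch(256, 8e-6, 515e-9, z))
           for z in (0.1, 1.0, 100.0)}
```

For this grid, N_F is still above one at 1 m and only drops well below one at much larger distances.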

gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel')

Definition to compute a hologram using the iterative Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

Parameters:

  • field
               Complex field (MxN).
    
  • n_iterations
               Number of Gerchberg-Saxton iterations.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • slm_range
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Type of the propagation (see odak.learn.wave.propagate_beam).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

  • reconstruction ( cfloat ) –

    Calculated reconstruction using calculated hologram.

Source code in odak/learn/wave/classical.py
def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel'):
    """
    Definition to compute a hologram using the iterative Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

    Parameters
    ----------
    field            : torch.cfloat
                       Complex field (MxN).
    n_iterations     : int
                       Number of Gerchberg-Saxton iterations.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    slm_range        : float
                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    propagation_type : str
                       Type of the propagation (see odak.learn.wave.propagate_beam).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    reconstruction   : torch.cfloat
                       Calculated reconstruction using calculated hologram. 
    """
    if n_iterations < 1:
        raise ValueError('n_iterations must be at least one.')
    k = wavenumber(wavelength)
    reconstruction = field
    for i in range(n_iterations):
        hologram = propagate_beam(
            reconstruction, k, -distance, dx, wavelength, propagation_type)
        reconstruction = propagate_beam(
            hologram, k, distance, dx, wavelength, propagation_type)
        reconstruction = set_amplitude(reconstruction, field)
    reconstruction = propagate_beam(
        hologram, k, distance, dx, wavelength, propagation_type)
    return hologram, reconstruction
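The loop above alternates between the hologram and image planes with propagate_beam, and replaces the image-plane amplitude with that of `field`. For comparison, the sketch below is the classic Fourier-plane variant of Gerchberg-Saxton in NumPy. It differs in two stated ways: it uses orthonormal FFTs instead of odak's propagators, and it also forces the hologram to be phase-only. The function name and the random initial phase are assumptions. With both constraints applied as projections, the image-plane error should not increase from one iteration to the next.

```python
import numpy as np

def gerchberg_saxton_fourier(target_amplitude, n_iterations, seed=0):
    if n_iterations < 1:
        raise ValueError("n_iterations must be at least one")
    rng = np.random.default_rng(seed)
    # Scale the target so its energy matches a unit-amplitude hologram (ortho FFTs preserve energy)
    target = target_amplitude * np.sqrt(target_amplitude.size / np.sum(target_amplitude ** 2))
    field = target * np.exp(2j * np.pi * rng.random(target.shape))  # random initial phase
    errors = []
    for _ in range(n_iterations):
        # Back to the hologram plane, then keep only the phase (phase-only constraint)
        hologram = np.exp(1j * np.angle(np.fft.ifft2(field, norm="ortho")))
        reconstruction = np.fft.fft2(hologram, norm="ortho")
        errors.append(np.linalg.norm(np.abs(reconstruction) - target))
        # Keep the reconstructed phase, impose the target amplitude
        field = target * np.exp(1j * np.angle(reconstruction))
    return hologram, reconstruction, errors

target = np.zeros((32, 32))
target[8:24, 8:24] = 1.0
hologram, reconstruction, errors = gerchberg_saxton_fourier(target, 20)
```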

get_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
    """
    Helper function for odak.learn.wave.angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. /2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. /2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    H = torch.exp(1j  * distance * (2 * (np.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
    H = H.to(device)
    return H
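get_angular_spectrum_kernel takes torch.sqrt of 1 - (λ fx)² - (λ fy)² without clamping. For float tensors, the square root of a negative number is NaN. The grid corners sit at |f| = √2 / (2 dx), so they become evanescent, and the kernel picks up NaN values, once dx < λ / √2. The NumPy check below rebuilds the same unclamped kernel to show this. Both function names are hypothetical.

```python
import numpy as np

def kernel_is_finite(n, dx, wavelength, distance):
    # Rebuilds the unclamped angular spectrum kernel on an n x n grid
    fx = np.linspace(-0.5 / dx, 0.5 / dx, n)
    FX, FY = np.meshgrid(fx, fx, indexing="ij")
    with np.errstate(invalid="ignore"):
        # sqrt of a negative float is NaN, as with torch.sqrt on float tensors
        root = np.sqrt(1.0 - (wavelength * FX) ** 2 - (wavelength * FY) ** 2)
        H = np.exp(1j * 2 * np.pi / wavelength * distance * root)
    return bool(np.all(np.isfinite(H)))

def has_evanescent_corners(dx, wavelength):
    # The grid corner frequency sqrt(2) / (2 dx) exceeds 1 / wavelength when dx < wavelength / sqrt(2)
    return wavelength / (np.sqrt(2.0) * dx) > 1.0
```

With the default 8 µm pitch at 515 nm, the kernel stays finite. A 300 nm pitch at the same wavelength produces NaN values at the corners.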

get_band_limited_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.band_limited_angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
    """
    Helper function for odak.learn.wave.band_limited_angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex
                         Complex kernel in Fourier domain.
    """
    x = dx * float(nu)
    y = dx * float(nv)
    fx = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * x),
                         1 / (2 * dx) - 0.5 / (2 * x),
                         nu,
                         dtype = torch.float32,
                         device = device
                        )
    fy = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * y),
                        1 / (2 * dx) - 0.5 / (2 * y),
                        nv,
                        dtype = torch.float32,
                        device = device
                       )
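    # With indexing='ij' the first output holds the fx values (shape [nu, nv]), so it is paired with fx_max below
    FX, FY = torch.meshgrid(fx, fy, indexing='ij')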
    HH_exp = 2 * np.pi * torch.sqrt(1 / wavelength ** 2 - (FX ** 2 + FY ** 2))
    distance = torch.tensor([distance], device = device)
    H_exp = torch.mul(HH_exp, distance)
    fx_max = 1 / torch.sqrt((2 * distance * (1 / x))**2 + 1) / wavelength
    fy_max = 1 / torch.sqrt((2 * distance * (1 / y))**2 + 1) / wavelength
    H_filter = ((torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).clone().detach()
    H = generate_complex_field(H_filter, H_exp)
    return H
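In get_band_limited_angular_spectrum_kernel, `fx_max` is derived from x = dx * nu. It must therefore be compared with the grid whose values come from `fx`. With `indexing='ij'`, meshgrid returns first the grid that holds the values of its first argument. The small non-square NumPy example below shows this convention; np.meshgrid with `indexing="ij"` behaves like torch.meshgrid with `indexing='ij'`.

```python
import numpy as np

fx = np.array([-2.0, -1.0, 0.0, 1.0])  # length nu = 4
fy = np.array([-0.5, 0.5])             # length nv = 2
first, second = np.meshgrid(fx, fy, indexing="ij")

# Both outputs have shape [nu, nv]:
# the first repeats fx down the rows, the second repeats fy across the columns
```

For square grids, the naming of the two outputs makes no numerical difference, because nu = nv. For non-square grids, each cutoff must meet its own axis.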

get_propagation_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), propagation_type='Bandlimited Angular Spectrum', scale=1)

Get propagation kernel for the propagation type.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • propagation_type
                 Propagation type.
                 The options are `Angular Spectrum`, `Bandlimited Angular Spectrum` and `Transfer Function Fresnel`.
    
  • scale
                 Scale factor for scaled beam propagation.
    

Returns:

  • kernel ( tensor ) –

    Complex kernel for the given propagation type.

Source code in odak/learn/wave/classical.py
def get_propagation_kernel(
                           nu, 
                           nv, 
                           dx = 8e-6, 
                           wavelength = 515e-9, 
                           distance = 0., 
                           device = torch.device('cpu'), 
                           propagation_type = 'Bandlimited Angular Spectrum', 
                           scale = 1
                          ):
    """
    Get propagation kernel for the propagation type.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    propagation_type   : str
                         Propagation type.
                         The options are `Angular Spectrum`, `Bandlimited Angular Spectrum` and `Transfer Function Fresnel`.
    scale              : int
                         Scale factor for scaled beam propagation.


    Returns
    -------
    kernel             : torch.tensor
                         Complex kernel for the given propagation type.
    """
    if propagation_type == 'Bandlimited Angular Spectrum':
        kernel = get_band_limited_angular_spectrum_kernel(nu, nv, dx, wavelength, distance, device)
    elif propagation_type == 'Angular Spectrum':
        kernel = get_angular_spectrum_kernel(nu, nv, dx, wavelength, distance, device)
    elif propagation_type == 'Transfer Function Fresnel':
        kernel = get_transfer_function_fresnel_kernel(nu, nv, dx, wavelength, distance, device)
    else:
        logging.warning('Propagation type not recognized')
        raise ValueError('Propagation type not recognized: {}'.format(propagation_type))
    return kernel
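The if/elif chain above maps a propagation name to a kernel builder. The NumPy sketch below does the same with a lookup table (the names are hypothetical) and compares two of the kernels. The Transfer Function Fresnel kernel, exp(i z (k - π λ (fx² + fy²))), is the second-order Taylor expansion of the angular spectrum phase. The two therefore differ only by roughly k z (λ² f²)² / 8 at the largest frequency f, which is small for these paraxial settings.

```python
import numpy as np

def _frequency_grid(nu, nv, dx):
    fx = np.linspace(-0.5 / dx, 0.5 / dx, nu)
    fy = np.linspace(-0.5 / dx, 0.5 / dx, nv)
    return np.meshgrid(fx, fy, indexing="ij")

def angular_spectrum_np(nu, nv, dx, wavelength, distance):
    FX, FY = _frequency_grid(nu, nv, dx)
    root = np.sqrt(np.clip(1.0 - (wavelength * FX) ** 2 - (wavelength * FY) ** 2, 0.0, None))
    return np.exp(1j * 2 * np.pi / wavelength * distance * root)

def transfer_function_fresnel_np(nu, nv, dx, wavelength, distance):
    FX, FY = _frequency_grid(nu, nv, dx)
    k = 2 * np.pi / wavelength
    return np.exp(1j * distance * (k - np.pi * wavelength * (FX ** 2 + FY ** 2)))

# Hypothetical name -> builder table standing in for the if/elif chain
KERNEL_BUILDERS = {
    "Angular Spectrum": angular_spectrum_np,
    "Transfer Function Fresnel": transfer_function_fresnel_np,
}

def get_kernel_np(nu, nv, dx, wavelength, distance, propagation_type):
    try:
        builder = KERNEL_BUILDERS[propagation_type]
    except KeyError:
        raise ValueError(f"Propagation type not recognized: {propagation_type}") from None
    return builder(nu, nv, dx, wavelength, distance)

H_as = get_kernel_np(32, 48, 8e-6, 515e-9, 1e-3, "Angular Spectrum")
H_tf = get_kernel_np(32, 48, 8e-6, 515e-9, 1e-3, "Transfer Function Fresnel")
```

Raising ValueError with the offending name makes a typo in `propagation_type` easier to diagnose than a bare assertion.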

get_transfer_function_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.transfer_function_fresnel.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
    """
    Helper function for odak.learn.wave.transfer_function_fresnel.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing = 'ij')
    k = wavenumber(wavelength)
    H = torch.exp(1j * distance * (k - torch.pi * wavelength * (FX ** 2 + FY ** 2)))
    return H
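The Transfer Function Fresnel kernel contains the chirp exp(-i π λ z f²). That chirp is sampled without aliasing only while its phase changes by less than π between neighbouring frequency samples. This leads to the rule of thumb in Voelz's book: dx ≥ λ z / L with L = N dx, equivalently z ≤ N dx² / λ. The NumPy sketch below computes that critical distance and measures the largest phase step on the same frequency grid the helper uses. The function names are hypothetical.

```python
import numpy as np

def critical_distance(n, dx, wavelength):
    # Distance at which dx equals wavelength * z / L, with L = n * dx
    return n * dx ** 2 / wavelength

def max_chirp_step(n, dx, wavelength, distance):
    # Largest phase change of the chirp pi * wavelength * z * fx**2 between
    # neighbouring samples of the fx grid used by the kernel helper
    fx = np.linspace(-0.5 / dx, 0.5 / dx, n)
    phase = np.pi * wavelength * distance * fx ** 2
    return np.max(np.abs(np.diff(phase)))

z_c = critical_distance(512, 8e-6, 515e-9)
step_short = max_chirp_step(512, 8e-6, 515e-9, 0.5 * z_c)  # well sampled
step_long = max_chirp_step(512, 8e-6, 515e-9, 2.0 * z_c)   # aliased chirp
```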

point_wise(target, wavelength, distance, dx, device, lens_size=401)

Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

Parameters:

  • target
               Float input target to be converted into a hologram (values should be in the range [0, 1]).
    
  • wavelength
               Wavelength of the electric field.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • device
               Device type (cuda or cpu).
    
  • lens_size
               Radius of the lens used to mask sub-holograms (in pixels).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

Source code in odak/learn/wave/classical.py
def point_wise(target, wavelength, distance, dx, device, lens_size=401):
    """
    Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

    Parameters
    ----------
    target           : torch.float
                       Float input target to be converted into a hologram (values should be in the range [0, 1]).
    wavelength       : float
                       Wavelength of the electric field.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    device           : torch.device
                       Device type (cuda or cpu).
    lens_size        : int
                       Radius of the lens used to mask sub-holograms (in pixels).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    """
    target = zero_pad(target)
    nx, ny = target.shape
    k = wavenumber(wavelength)
    x = torch.linspace(-nx/2, nx/2, nx).to(device)
    y = torch.linspace(-ny/2, ny/2, ny).to(device)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = (X**2+Y**2)**0.5
    mask = (torch.abs(Z) <= lens_size)
    fz = quadratic_phase_function(nx, ny, k, focal=-distance, dx=dx).to(device)
    A = torch.nan_to_num(target**0.5, nan=0.0)
    fz = mask*fz
    FA = torch.fft.fft2(torch.fft.fftshift(A))
    FFZ = torch.fft.fft2(torch.fft.fftshift(fz))
    H = torch.mul(FA, FFZ)
    hologram = torch.fft.ifftshift(torch.fft.ifft2(H))
    hologram = crop_center(hologram)
    return hologram
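point_wise convolves the square root of the target (its amplitude) with a quadratic-phase sub-hologram, masked to a disc of lens_size pixels, using FFTs. The NumPy sketch below follows the same steps under stated assumptions. The spherical-wave phase exp(i k r² / (2 z)) stands in for odak's quadratic_phase_function, whose exact convention is not shown here. Zero padding and cropping are omitted, so the convolution is circular. `point_wise_np` is a hypothetical name. For a single bright pixel, the hologram should be the disc-shaped sub-hologram centred on that pixel.

```python
import numpy as np

def point_wise_np(target, wavelength, distance, dx, lens_radius):
    ny, nx = target.shape
    k = 2 * np.pi / wavelength
    # Pixel-index grids centred on index n // 2 (moved to index 0 by ifftshift below)
    y = np.arange(ny) - ny // 2
    x = np.arange(nx) - nx // 2
    Y, X = np.meshgrid(y, x, indexing="ij")
    inside = np.sqrt(X ** 2 + Y ** 2) <= lens_radius
    # Assumed paraxial spherical-wave phase for a point at `distance`, masked to a disc
    r2 = (X * dx) ** 2 + (Y * dx) ** 2
    sub_hologram = np.exp(1j * k * r2 / (2 * distance)) * inside
    amplitude = np.sqrt(np.clip(target, 0.0, None))
    # Circular convolution of the amplitude map with the sub-hologram via FFTs
    return np.fft.ifft2(np.fft.fft2(amplitude) * np.fft.fft2(np.fft.ifftshift(sub_hologram)))

target = np.zeros((64, 64))
target[32, 32] = 1.0
h = point_wise_np(target, 515e-9, 0.1, 8e-6, 8)
```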

propagate_beam(field, k, distance, dx, wavelength, propagation_type='Bandlimited Angular Spectrum', kernel=None, zero_padding=[True, False, True], aperture=1.0)

Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • propagation_type (str, default: 'Bandlimited Angular Spectrum' ) –
               Type of the propagation.
               The options are Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer and custom.
    
  • kernel
               Custom complex kernel, used when propagation_type is custom.
    
  • zero_padding
               Zero pads the input field if the first item in the list is set to True.
               Zero pads in the Fourier domain if the second item in the list is set to True.
               Crops the central region of the result to half resolution if the third item in the list is set to True.
               Note that in Fraunhofer propagation, setting the second item to True or False has no effect.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as the input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def propagate_beam(
                   field, 
                   k, 
                   distance, 
                   dx, 
                   wavelength, 
                   propagation_type='Bandlimited Angular Spectrum', 
                   kernel = None, 
                   zero_padding = [True, False, True],
                   aperture = 1.
                  ):
    """
    Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    propagation_type : str
                       Type of the propagation.
                       The options are Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer and custom.
    kernel           : torch.complex
                       Custom complex kernel, used when propagation_type is custom.
    zero_padding     : list
                       Zero pads the input field if the first item in the list is set to True.
                       Zero pads in the Fourier domain if the second item in the list is set to True.
                       Crops the central region of the result to half resolution if the third item in the list is set to True.
                       Note that in Fraunhofer propagation, setting the second item to True or False has no effect.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as the input field [m x n].

    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    if zero_padding[0]:
        field = zero_pad(field)
    if propagation_type == 'Angular Spectrum':
        result = angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
    elif propagation_type == 'Bandlimited Angular Spectrum':
        result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
    elif propagation_type == 'Transfer Function Fresnel':
        result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
    elif propagation_type == 'custom':
        result = custom(field, kernel, zero_padding[1], aperture = aperture)
    elif propagation_type == 'Fraunhofer':
        result = fraunhofer(field, k, distance, dx, wavelength)
    else:
        logging.warning('Propagation type not recognized')
        raise ValueError('Propagation type not recognized: {}'.format(propagation_type))
    if zero_padding[2]:
        result = crop_center(result)
    return result
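With the default zero_padding=[True, False, True], propagate_beam does three things in order. It pads the field, propagates on the larger grid, and crops the result back with crop_center. Padding limits the wrap-around caused by the circular convolution that the FFT implies. The NumPy sketch below shows that pipeline. `zero_pad_np` and `crop_center_np` are hypothetical stand-ins, not odak's zero_pad and crop_center. They assume padding to twice the size on each axis, which is consistent with the half-resolution crop described above.

```python
import numpy as np

def zero_pad_np(field):
    # Assumed behaviour: embed the field in the centre of a grid twice as large
    m, n = field.shape
    out = np.zeros((2 * m, 2 * n), dtype=field.dtype)
    out[m // 2: m // 2 + m, n // 2: n // 2 + n] = field
    return out

def crop_center_np(field):
    # Inverse of zero_pad_np: keep the central half along each axis
    m, n = field.shape[0] // 2, field.shape[1] // 2
    return field[m // 2: m // 2 + m, n // 2: n // 2 + n]

def propagate_padded(field, kernel_fn, distance):
    padded = zero_pad_np(field)
    H = kernel_fn(*padded.shape, distance)  # kernel built for the padded grid
    U = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(padded)))
    out = np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(H * U)))
    return crop_center_np(out)

def identity_kernel(nu, nv, distance):
    # Trivial kernel used only to check the pad -> propagate -> crop bookkeeping
    return np.ones((nu, nv))

rng = np.random.default_rng(2)
field = rng.random((16, 24)) + 1j * rng.random((16, 24))
out = propagate_padded(field, identity_kernel, 1e-3)
```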

shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None)

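Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. Implemented following the reference code at https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207 and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.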

Parameters:

  • phase
               Phase value of a phase-only hologram.
    
  • depth_shift
               Distance in meters.
    
  • pixel_pitch
               Pixel pitch size in meters.
    
  • wavelength
               Wavelength of light.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Beam propagation type. For more see odak.learn.wave.propagate_beam().
    
  • kernel_length
               Kernel length for the Gaussian blur kernel.
    
  • sigma
               Standard deviation for the Gaussian blur kernel.
    
  • amplitude
               Amplitude value of a complex hologram.
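The double phase step in shift_w_double_phase does three things. It normalises the shifted amplitude a to its maximum, removes the mean phase, and splits each complex value into two phases, φ - arccos(a) and φ + arccos(a). The average of the two unit phasors is exp(iφ) cos(arccos a) = a exp(iφ), which is why two phase-only values can encode one complex value. The NumPy sketch below checks that identity. The checkerboard interleaving in `checkerboard_merge` is a hypothetical pattern chosen for illustration.

```python
import numpy as np

def double_phase_np(amplitude, phase):
    # Assumes amplitude is already normalised to [0, 1]
    offset = np.arccos(amplitude)
    return phase - offset, phase + offset

def checkerboard_merge(phase_low, phase_high):
    # Hypothetical interleaving: phase_high where (row + col) is even, phase_low elsewhere
    rows, cols = np.indices(phase_low.shape)
    return np.where((rows + cols) % 2 == 0, phase_high, phase_low)

rng = np.random.default_rng(3)
amplitude = rng.random((8, 8))
phase = 2 * np.pi * rng.random((8, 8)) - np.pi
low, high = double_phase_np(amplitude, phase)
# Averaging the two unit phasors recovers amplitude * exp(1j * phase)
average = 0.5 * (np.exp(1j * low) + np.exp(1j * high))
phase_only = checkerboard_merge(low, high)
```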
    
Source code in odak/learn/wave/classical.py
def shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None):
    """
    Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. Implemented following [this reference code](https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207) and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

    Parameters
    ----------
    phase            : torch.tensor
                       Phase value of a phase-only hologram.
    depth_shift      : float
                       Distance in meters.
    pixel_pitch      : float
                       Pixel pitch size in meters.
    wavelength       : float
                       Wavelength of light.
    propagation_type : str
                       Beam propagation type. For more see odak.learn.wave.propagate_beam().
    kernel_length    : int
                       Kernel length for the Gaussian blur kernel.
    sigma            : float
                       Standard deviation for the Gaussian blur kernel.
    amplitude        : torch.tensor
                       Amplitude value of a complex hologram.
    """
    if amplitude is None:
        amplitude = torch.ones_like(phase)
    hologram = generate_complex_field(amplitude, phase)
    k = wavenumber(wavelength)
    hologram_padded = zero_pad(hologram)
    shifted_field_padded = propagate_beam(
                                          hologram_padded,
                                          k,
                                          depth_shift,
                                          pixel_pitch,
                                          wavelength,
                                          propagation_type
                                         )
    shifted_field = crop_center(shifted_field_padded)
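    # Global phase factor exp(-i * 2 * pi * depth_shift / wavelength), evaluated in double precision.
    # (Taking torch.exp of the real phase first would underflow to zero and make the shift a no-op.)
    phase_shift = -2 * np.pi * depth_shift / wavelength
    shift = complex(np.cos(phase_shift), np.sin(phase_shift))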
    shifted_complex_hologram = shifted_field * shift

    if kernel_length > 0 and sigma >0:
        blur_kernel = generate_2d_gaussian(
                                           [kernel_length, kernel_length],
                                           [sigma, sigma]
                                          ).to(phase.device)
        blur_kernel = blur_kernel.unsqueeze(0)
        blur_kernel = blur_kernel.unsqueeze(0)
        field_imag = torch.imag(shifted_complex_hologram)
        field_real = torch.real(shifted_complex_hologram)
        field_imag = field_imag.unsqueeze(0)
        field_imag = field_imag.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_imag = torch.nn.functional.conv2d(field_imag, blur_kernel, padding='same')
        field_real = torch.nn.functional.conv2d(field_real, blur_kernel, padding='same')
        shifted_complex_hologram = torch.complex(field_real, field_imag)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)

    shifted_amplitude = calculate_amplitude(shifted_complex_hologram)
    shifted_amplitude = shifted_amplitude / torch.amax(shifted_amplitude, [0,1])

    shifted_phase = calculate_phase(shifted_complex_hologram)
    phase_zero_mean = shifted_phase - torch.mean(shifted_phase)

    phase_offset = torch.arccos(shifted_amplitude)
    phase_low = phase_zero_mean - phase_offset
    phase_high = phase_zero_mean + phase_offset

    phase_only = torch.zeros_like(phase)
    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
    return phase_only
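The final step of shift_w_double_phase relies on the double-phase identity: for a normalized amplitude a between zero and one, a * exp(i * phi) = 0.5 * (exp(i * (phi - arccos(a))) + exp(i * (phi + arccos(a)))), which is why the amplitude is divided by its maximum before torch.arccos is applied. The dependency-free sketch below checks that identity and reproduces the checkerboard interleaving on plain Python lists; the helper names double_phase_pair and checkerboard_interleave are hypothetical and not part of odak.

```python
import cmath
import math


def double_phase_pair(amplitude, phase):
    """Split amplitude * exp(1j * phase), with 0 <= amplitude <= 1, into two
    unit-amplitude phases (low, high) whose average reproduces that value."""
    offset = math.acos(amplitude)  # same role as torch.arccos(shifted_amplitude)
    return phase - offset, phase + offset


def checkerboard_interleave(phase_low, phase_high):
    """Interleave two 2D phase maps (lists of rows) in a checkerboard.
    Even (row + column) positions take phase_low and odd ones take phase_high,
    matching the [0::2, 0::2] / [0::2, 1::2] / [1::2, 0::2] / [1::2, 1::2] slicing."""
    rows, cols = len(phase_low), len(phase_low[0])
    return [
        [phase_low[r][c] if (r + c) % 2 == 0 else phase_high[r][c] for c in range(cols)]
        for r in range(rows)
    ]


# Check the identity for one complex sample.
low, high = double_phase_pair(0.6, 0.3)
recovered = 0.5 * (cmath.exp(1j * low) + cmath.exp(1j * high))
print(abs(recovered - 0.6 * cmath.exp(0.3j)) < 1e-12)  # True

# Interleave two constant 2 x 3 maps to see the pattern.
print(checkerboard_interleave([[0.0] * 3] * 2, [[1.0] * 3] * 2))
# [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]
```

Averaging two unit phasors can only produce magnitudes up to one, which is why the normalization step comes first.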

stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type='Bandlimited Angular Spectrum', n_iteration=100, loss_function=None, learning_rate=0.1)

Definition to generate phase and reconstruction from target image via stochastic gradient descent.

Parameters:

  • target
                        Target intensity [m x n], compared against the squared amplitude of the reconstruction.
                        Keep the target values between zero and one.
    
  • wavelength
                        Wavelength of light in meters.
    
  • distance
                        Distance between the SLM (hologram) plane and the reconstruction plane.
    
  • pixel_pitch
                        SLM pixel pitch in meters.
    
  • propagation_type
                        Type of the propagation (see odak.learn.wave.propagate_beam()).
    
  • n_iteration
                        Number of iterations.
    
  • loss_function
                        Loss function. If None, the L2 (mean squared error) loss is used.
    
  • learning_rate
                        Learning rate.
    

Returns:

  • hologram ( Tensor ) –

    Phase-only hologram as a complex torch tensor.

  • reconstruction ( Tensor ) –

    Complex field reconstructed at the target plane.

Source code in odak/learn/wave/classical.py
def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type = 'Bandlimited Angular Spectrum', n_iteration = 100, loss_function = None, learning_rate = 0.1):
    """
    Definition to generate phase and reconstruction from target image via stochastic gradient descent.

    Parameters
    ----------
    target                    : torch.Tensor
                                Target intensity [m x n], compared against the squared amplitude of the reconstruction.
                                Keep the target values between zero and one.
    wavelength                : float
                                Wavelength of light in meters.
    distance                  : float
                                Distance between the SLM (hologram) plane and the reconstruction plane.
    pixel_pitch               : float
                                SLM pixel pitch in meters.
    propagation_type          : str
                                Type of the propagation (see odak.learn.wave.propagate_beam()).
    n_iteration               : int
                                Number of iterations.
    loss_function             : function
                                Loss function. If None, the L2 (mean squared error) loss is used.
    learning_rate             : float
                                Learning rate.

    Returns
    -------
    hologram                  : torch.Tensor
                                Phase-only hologram as a complex torch tensor.

    reconstruction            : torch.Tensor
                                Complex field reconstructed at the target plane.

    """
    phase = torch.randn_like(target, requires_grad = True)
    k = wavenumber(wavelength)
    optimizer = torch.optim.Adam([phase], lr = learning_rate)
    if type(loss_function) == type(None):
        loss_function = torch.nn.MSELoss()
    t = tqdm(range(n_iteration), leave = False, dynamic_ncols = True)
    for i in t:
        optimizer.zero_grad()
        hologram = generate_complex_field(1., phase)
        reconstruction = propagate_beam(
                                        hologram, 
                                        k, 
                                        distance, 
                                        pixel_pitch, 
                                        wavelength, 
                                        propagation_type, 
                                        zero_padding = [True, False, True]
                                       )
        reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
        loss = loss_function(reconstruction_intensity, target)
        description = "Loss:{:.4f}".format(loss.item())
        loss.backward(retain_graph = True)
        optimizer.step()
        t.set_description(description)
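    print(description)
    # Final forward pass without building an autograd graph
    # (a bare torch.no_grad() call has no effect; it must be used as a context manager).
    with torch.no_grad():
        hologram = generate_complex_field(1., phase)
        reconstruction = propagate_beam(
                                        hologram,
                                        k,
                                        distance,
                                        pixel_pitch,
                                        wavelength,
                                        propagation_type,
                                        zero_padding = [True, False, True]
                                       )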
    return hologram, reconstruction
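The optimization loop above can be illustrated without odak or PyTorch. The sketch below is a simplified, assumption-laden stand-in, not odak's implementation: a unitary 1D DFT replaces propagate_beam (a far-field model, not the band-limited angular spectrum method), plain full-batch gradient descent with a hand-derived gradient replaces Adam and autograd, and the loss is the summed squared intensity error. All function names are hypothetical.

```python
import cmath
import math
import random


def dft(u):
    """Unitary 1D DFT, a toy stand-in for propagate_beam."""
    n = len(u)
    return [
        sum(u[m] * cmath.exp(-2j * math.pi * k * m / n) for m in range(n)) / math.sqrt(n)
        for k in range(n)
    ]


def loss_and_grad(phase, target):
    """Loss sum_k (|U_k|^2 - t_k)^2 with U = dft(exp(1j * phase)), and its gradient."""
    n = len(phase)
    u = [cmath.exp(1j * p) for p in phase]
    U = dft(u)
    residual = [abs(Uk) ** 2 - t for Uk, t in zip(U, target)]
    loss = sum(r * r for r in residual)
    grad = []
    for m in range(n):
        g = 0.0
        for k in range(n):
            F_km = cmath.exp(-2j * math.pi * k * m / n) / math.sqrt(n)
            # dI_k/dphase_m = -2 * Im(conj(U_k) * F_km * u_m)
            dI = -2.0 * (U[k].conjugate() * F_km * u[m]).imag
            g += 2.0 * residual[k] * dI
        grad.append(g)
    return loss, grad


def gradient_descent_phase(target, n_iteration=300, learning_rate=0.002, seed=0):
    """Plain gradient descent on the phase; returns the phase and the loss history."""
    rng = random.Random(seed)
    phase = [rng.uniform(-math.pi, math.pi) for _ in target]  # random start
    history = []
    for _ in range(n_iteration):
        loss, grad = loss_and_grad(phase, target)
        history.append(loss)
        phase = [p - learning_rate * g for p, g in zip(phase, grad)]
    return phase, history


# An achievable target: the intensity produced by a known phase pattern.
true_phase = [0.0, 1.0, -2.0, 0.5, 2.5, -1.0, 0.3, 1.7]
target = [abs(Uk) ** 2 for Uk in dft([cmath.exp(1j * p) for p in true_phase])]
phase, history = gradient_descent_phase(target)
print(history[0], history[-1])
```

The random initial phase matters: in this model any constant phase is a stationary point of the loss, so the gradient vanishes there and descent would never start.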

transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution-based Fresnel approximation for beam propagation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
    """
    A definition to calculate convolution-based Fresnel approximation for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_transfer_function_fresnel_kernel(
                                             field.shape[-2], 
                                             field.shape[-1], 
                                             dx = dx, 
                                             wavelength = wavelength, 
                                             distance = distance, 
                                             device = field.device
                                            )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
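For reference, the textbook Fresnel transfer function (Goodman, Introduction to Fourier Optics) is H(fx, fy) = exp(i * k * z) * exp(-i * pi * wavelength * z * (fx^2 + fy^2)). The plain-Python sketch below evaluates it on an FFT-ordered frequency grid. Whether odak's get_transfer_function_fresnel_kernel keeps the constant exp(i * k * z) factor, uses the same sign convention, or orders frequencies the same way is an assumption not confirmed by the source shown here; fft_frequencies and fresnel_transfer_function are hypothetical helpers.

```python
import cmath
import math


def fft_frequencies(n, dx):
    """Sample frequencies in FFT order, matching numpy.fft.fftfreq(n, dx)."""
    return [(i if i < (n + 1) // 2 else i - n) / (n * dx) for i in range(n)]


def fresnel_transfer_function(nx, ny, dx, wavelength, distance):
    """Goodman's Fresnel transfer function on an nx x ny FFT-ordered grid (list of rows)."""
    k = 2 * math.pi / wavelength
    constant = cmath.exp(1j * k * distance)  # constant phase exp(i k z); some implementations omit it
    fx = fft_frequencies(nx, dx)
    fy = fft_frequencies(ny, dx)
    return [
        [constant * cmath.exp(-1j * math.pi * wavelength * distance * (u * u + v * v)) for v in fy]
        for u in fx
    ]


H = fresnel_transfer_function(4, 4, dx=8e-6, wavelength=515e-9, distance=0.01)
print(max(abs(abs(h) - 1.0) for row in H for h in row) < 1e-12)  # True: unit magnitude
```

Every entry has unit magnitude, so in this form the kernel only re-phases spatial frequencies; the aperture argument of transfer_function_fresnel is what removes any of them.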

blazed_grating(nx, ny, levels=2, axis='x')

A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • levels
           Number of phase levels, i.e., pixels per grating period (values below two are raised to two).
    
  • axis
           Axis of the blazed grating. It could be `x` or `y`.
    
Source code in odak/learn/wave/lens.py
def blazed_grating(nx, ny, levels = 2, axis = 'x'):
    """
    A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.


    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    levels       : int
                   Number of phase levels, i.e., pixels per grating period (values below two are raised to two).
    axis         : str
                   Axis of the blazed grating. It could be `x` or `y`.

    """
    if levels < 2:
        levels = 2
    x = (torch.abs(torch.arange(-nx, 0)) % levels) / levels * (2 * np.pi)
    y = (torch.abs(torch.arange(-ny, 0)) % levels) / levels * (2 * np.pi)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'x':
        blazed_grating = torch.exp(1j * X)
    elif axis == 'y':
        blazed_grating = torch.exp(1j * Y)
    return blazed_grating
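Along the grating axis, blazed_grating assigns each pixel the phase (|i| mod levels) / levels * 2 * pi for i running from -nx to -1, so the phase steps by 2 * pi / levels per pixel (descending, because of the absolute value) and repeats every levels pixels. With a period of levels * dx, the standard grating equation at normal incidence, sin(theta) = wavelength / period, gives the first-order steering angle. The sketch below reproduces the 1D ramp in plain Python; blazed_phase_1d and first_order_angle are hypothetical helpers, and dx is an assumed pixel pitch that blazed_grating itself does not take.

```python
import math


def blazed_phase_1d(n, levels=2):
    """Phase ramp that blazed_grating uses along its grating axis, in radians."""
    levels = max(levels, 2)  # blazed_grating raises levels below two to two
    return [(abs(i) % levels) / levels * 2 * math.pi for i in range(-n, 0)]


def first_order_angle(wavelength, dx, levels):
    """First-order deflection angle in radians, sin(theta) = wavelength / (levels * dx)."""
    return math.asin(wavelength / (levels * dx))


print(blazed_phase_1d(8, levels=4))  # two periods of a four-level ramp
print(math.degrees(first_order_angle(515e-9, 8e-6, levels=4)))
```

More levels give a finer staircase but a longer period, and therefore a smaller steering angle.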

linear_grating(nx, ny, every=2, add=None, axis='x')

A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • every
         Spacing, in pixels, at which the `add` phase is applied.
    
  • add
         Phase to be added, in radians. Defaults to pi.
    
  • axis
         Axis of the grating: `x`, `y`, or `xy` (both).
    

Returns:

  • field ( tensor ) –

    Linear grating term.

Source code in odak/learn/wave/lens.py
def linear_grating(nx, ny, every = 2, add = None, axis = 'x'):
    """
    A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    every      : int
                 Spacing, in pixels, at which the `add` phase is applied.
    add        : float
                 Phase to be added, in radians. Defaults to pi.
    axis       : string
                 Axis of the grating: `x`, `y`, or `xy` (both).

    Returns
    -------
    field      : torch.tensor
                 Linear grating term.
    """
    if isinstance(add, type(None)):
        add = np.pi
    grating = torch.zeros((nx, ny), dtype=torch.complex64)
    if axis == 'x':
        grating[::every, :] = torch.exp(torch.tensor(1j*add))
    if axis == 'y':
        grating[:, ::every] = torch.exp(torch.tensor(1j*add))
    if axis == 'xy':
        checker = np.indices((nx, ny)).sum(axis=0) % every
        checker = torch.from_numpy(checker)
        checker += 1
        checker = checker % 2
        grating = torch.exp(1j*checker*add)
    return grating
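The `xy` branch of linear_grating builds its mask as ((i + j) mod every + 1) mod 2 and then applies exp(1j * mask * add). For every = 2 this is a checkerboard; for larger values the final mod 2 means the pattern is not a simple binary grating with period every, as the every = 3 case below shows. A plain-Python mirror of that branch, with hypothetical helper names:

```python
import cmath
import math


def checker_phase_mask(nx, ny, every=2):
    """Mirror of the `xy` branch: ((i + j) % every + 1) % 2 for each pixel."""
    return [[((i + j) % every + 1) % 2 for j in range(ny)] for i in range(nx)]


def linear_grating_xy(nx, ny, every=2, add=math.pi):
    """Complex grating exp(1j * mask * add), as a list of rows."""
    return [[cmath.exp(1j * c * add) for c in row] for row in checker_phase_mask(nx, ny, every)]


print(checker_phase_mask(3, 3))  # [[1, 0, 1], [0, 1, 0], [1, 0, 1]]
print(checker_phase_mask(1, 6, every=3))  # [[1, 0, 1, 1, 0, 1]]
```

With the default add = pi, the `xy` grating alternates between -1 and +1, i.e. a binary phase checkerboard.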

prism_grating(nx, ny, k, angle, dx=0.001, axis='x', phase_offset=0.0)

A definition to generate a 2D phase function that represents a prism. For more, see Goodman's Introduction to Fourier Optics or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics Express 16.22 (2008): 18275-18287.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • k
           See odak.wave.wavenumber for more.
    
  • angle
           Tilt angle of the prism in degrees.
    
  • dx
           Pixel pitch.
    
  • axis
           Axis of the prism.
    
  • phase_offset (float, default: 0.0 ) –
           Phase offset in degrees. Default is zero.
    

Returns:

  • prism ( tensor ) –

    Generated phase function for a prism.

Source code in odak/learn/wave/lens.py
def prism_grating(nx, ny, k, angle, dx = 0.001, axis = 'x', phase_offset = 0.):
    """
    A definition to generate a 2D phase function that represents a prism. For more, see Goodman's Introduction to Fourier Optics or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics Express 16.22 (2008): 18275-18287.

    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    k            : odak.wave.wavenumber
                   See odak.wave.wavenumber for more.
    angle        : float
                   Tilt angle of the prism in degrees.
    dx           : float
                   Pixel pitch.
    axis         : str
                   Axis of the prism.
    phase_offset : float
                   Phase offset in degrees. Default is zero.

    Returns
    -------
    prism        : torch.tensor
                   Generated phase function for a prism.
    """
    angle = torch.deg2rad(torch.tensor([angle]))
    phase_offset = torch.deg2rad(torch.tensor([phase_offset]))
    x = torch.arange(0, nx) * dx
    y = torch.arange(0, ny) * dx
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'y':
        phase = k * torch.sin(angle) * Y + phase_offset
        prism = torch.exp(-1j * phase)
    elif axis == 'x':
        phase = k * torch.sin(angle) * X + phase_offset
        prism = torch.exp(-1j * phase)
    return prism
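prism_grating applies a linear phase k * sin(angle) * x + phase_offset, so neighboring pixels differ by k * sin(angle) * dx. Sampling limits the usable tilt: keeping that step at or below pi (Nyquist) requires sin(angle) <= wavelength / (2 * dx). The sketch below computes both in plain Python, assuming k = 2 * pi / wavelength (the usual wavenumber); the helper names are hypothetical.

```python
import math


def prism_phase_1d(n, wavelength, angle_deg, dx, phase_offset_deg=0.0):
    """Unwrapped phase k * sin(angle) * x + offset along one axis, as in prism_grating
    (which then applies exp(-1j * phase)); k = 2 * pi / wavelength is assumed."""
    k = 2 * math.pi / wavelength
    angle = math.radians(angle_deg)
    offset = math.radians(phase_offset_deg)
    return [k * math.sin(angle) * (i * dx) + offset for i in range(n)]


def max_unaliased_angle_deg(wavelength, dx):
    """Largest tilt whose per-pixel phase step k * sin(angle) * dx stays at or below pi."""
    return math.degrees(math.asin(wavelength / (2 * dx)))


phase = prism_phase_1d(5, 515e-9, 1.0, 8e-6)
step = phase[1] - phase[0]  # constant phase increment per pixel
print(step, max_unaliased_angle_deg(515e-9, 8e-6))
```

With 8 micrometer pixels at 515 nm, the limit is slightly below two degrees, so the 1 degree example stays within it.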

quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • k
         See odak.wave.wavenumber for more.
    
  • focal
         Focal length of the quadratic phase function.
    
  • dx
         Pixel pitch.
    
  • offset
         Deviation from the center in pixels; `offset[1]` shifts the grid along X and `offset[0]` along Y.
    

Returns:

  • function ( tensor ) –

    Generated quadratic phase function.

Source code in odak/learn/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):
    """ 
    A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    k          : odak.wave.wavenumber
                 See odak.wave.wavenumber for more.
    focal      : float
                 Focal length of the quadratic phase function.
    dx         : float
                 Pixel pitch.
    offset     : list
                 Deviation from the center in pixels; `offset[1]` shifts the grid along X and `offset[0]` along Y.

    Returns
    -------
    function   : torch.tensor
                 Generated quadratic phase function.
    """
    size = [nx, ny]
    x = torch.linspace(-size[0] * dx / 2, size[0] * dx / 2, size[0]) - offset[1] * dx
    y = torch.linspace(-size[1] * dx / 2, size[1] * dx / 2, size[1]) - offset[0] * dx
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = X**2 + Y**2
    qwf = torch.exp(-0.5j * k / focal * Z)
    return qwf
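quadratic_phase_function evaluates the paraxial thin-lens phase -k * r^2 / (2 * focal), with r^2 = X^2 + Y^2. Because the phase is sampled at pitch dx, its local spatial frequency r / (wavelength * focal) must stay below 1 / (2 * dx) out to the edge of the grid at r = n * dx / 2, which gives the aliasing-free condition focal >= n * dx^2 / wavelength along the grid axes. The function shown does not check this; the plain-Python sketch below (hypothetical helper names) computes the bound.

```python
import math


def lens_phase(r, wavelength, focal):
    """Paraxial thin-lens phase -k * r^2 / (2 * focal), i.e. -0.5 * k / focal * Z with Z = r^2."""
    k = 2 * math.pi / wavelength
    return -0.5 * k / focal * r * r


def min_unaliased_focal(n, dx, wavelength):
    """Smallest focal length whose sampled lens phase does not alias along the grid axes:
    r / (wavelength * f) <= 1 / (2 * dx) at r = n * dx / 2  =>  f >= n * dx**2 / wavelength."""
    return n * dx * dx / wavelength


print(round(min_unaliased_focal(1000, 8e-6, 515e-9), 4))  # 0.1243
```

For a 1000-pixel grid with 8 micrometer pixels at 515 nm, focal lengths shorter than about 0.124 m alias at the edge of the grid.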

multiplane_loss

Loss function for multiplanar images. Unlike previous methods, this loss function accounts for the defocused parts of an image.

Source code in odak/learn/wave/loss.py
class multiplane_loss():
    """
    Loss function for multiplanar images. Unlike previous methods, this loss function accounts for the defocused parts of an image.
    """


    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
                 target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6, 0.], 
                 multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
        """
        Parameters
        ----------
        target_image      : torch.tensor
                            Color target image [3 x m x n].
        target_depth      : torch.tensor
                            Monochrome target depth, same resolution as target_image.
        target_blur_size  : int
                            Maximum target blur size.
        blur_ratio        : float
                            Blur ratio, a value between zero and one.
        number_of_planes  : int
                            Number of planes.
        weights           : list
                            Weights of the loss function.
        multiplier        : float
                            Multiplier to multipy with targets.
        scheme            : str
                            The type of the loss, `naive` without defocus or `defocus` with defocus.
        reduction         : str
                            Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
        device            : torch.device
                            Device to be used (e.g., cuda, cpu, opencl).
        """
        self.device = device
        self.target_image     = target_image.float().to(self.device)
        self.target_depth     = target_depth.float().to(self.device)
        self.target_blur_size = target_blur_size
        if self.target_blur_size % 2 == 0:
            self.target_blur_size += 1
        self.number_of_planes = number_of_planes
        self.multiplier       = multiplier
        self.weights          = weights
        self.reduction        = reduction
        self.blur_ratio       = blur_ratio
        self.set_targets()
        if scheme == 'defocus':
            self.add_defocus_blur()
        self.loss_function = torch.nn.MSELoss(reduction = self.reduction)


    def get_targets(self):
        """
        Returns
        -------
        targets           : torch.tensor
                            Returns a copy of the targets.
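        focus_target      : torch.tensor
                            Returns a copy of the in-focus target image (planes combined, without defocus blur).
        target_depth      : torch.tensor
                            Returns a copy of the normalized quantized depth map.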

        """
        divider = self.number_of_planes - 1
        if divider == 0:
            divider = 1
        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider


    def set_targets(self):
        """
        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
        """
        self.target_depth = self.target_depth * (self.number_of_planes - 1)
        self.target_depth = torch.round(self.target_depth, decimals = 0)
        self.targets      = torch.zeros(
                                        self.number_of_planes,
                                        self.target_image.shape[0],
                                        self.target_image.shape[1],
                                        self.target_image.shape[2],
                                        requires_grad = False,
                                        device = self.device
                                       )
        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
        self.masks        = torch.zeros_like(self.targets)
        for i in range(self.number_of_planes):
            for ch in range(self.target_image.shape[0]):
                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
                new_target = self.target_image[ch] * mask
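                # Accumulate into channel ch only; adding an [m x n] plane to the whole
                # [3 x m x n] tensor would broadcast it into every channel.
                self.focus_target[ch] = self.focus_target[ch] + new_target.squeeze(0).squeeze(0).detach().clone()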
                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
                self.masks[i, ch] = mask.detach().clone() 


    def add_defocus_blur(self):
        """
        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
        """
        kernel_length = [self.target_blur_size, self.target_blur_size ]
        for ch in range(self.target_image.shape[0]):
            targets_cache = self.targets[:, ch].detach().clone()
            target = torch.sum(targets_cache, axis = 0)
            for i in range(self.number_of_planes):
                sigmas = torch.linspace(start = 0, end = self.target_blur_size, steps = self.number_of_planes)
                sigmas = sigmas - i * self.target_blur_size / (self.number_of_planes - 1 + 1e-10)
                defocus = torch.zeros_like(targets_cache[i])
                for j in range(self.number_of_planes):
                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                    if torch.sum(targets_cache[j]) > 0:
                        if i == j:
                            nsigma = [0., 0.]
                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                        kernel = kernel / torch.sum(kernel)
                        kernel = kernel.unsqueeze(0).unsqueeze(0)
                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
                self.targets[i, ch] = defocus
        self.targets = self.targets.detach().clone() * self.multiplier


    def __call__(self, image, target, plane_id = None):
        """
        Calculates the multiplane loss against a given target.

        Parameters
        ----------
        image         : torch.tensor
                        Image to compare with a target [3 x m x n].
        target        : torch.tensor
                        Target image for comparison [3 x m x n].
        plane_id      : int
                        Index of the plane under test. If None, the masks of all planes are used.

        Returns
        -------
        loss          : torch.tensor
                        Computed loss.
        """
        l2 = self.weights[0] * self.loss_function(image, target)
        if isinstance(plane_id, type(None)):
            mask = self.masks
        else:
            mask= self.masks[plane_id, :]
        l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
        l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
        loss = l2 + l2_mask + l2_cor
        return loss
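The core of the class can be mirrored on flat Python lists: set_targets quantizes the normalized depth map into plane indices with round(depth * (number_of_planes - 1)), and __call__ sums three weighted MSE terms. The sketch below is illustrative only (hypothetical names, 1D lists instead of [3 x m x n] tensors); as the source shows, __call__ uses only the first three entries of weights.

```python
def quantize_depth(depth, number_of_planes):
    """Map normalized depth in [0, 1] to plane indices, as set_targets does with
    torch.round(depth * (number_of_planes - 1)). Python's round() and torch.round
    both round halves to even."""
    return [round(d * (number_of_planes - 1)) for d in depth]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def multiplane_loss_1d(image, target, mask, weights=(1.0, 2.1, 0.6)):
    """The three weighted terms of __call__ on flat lists: plain L2, masked L2, and
    the correlation-style term mse(image * target, target * target)."""
    l2 = weights[0] * mse(image, target)
    l2_mask = weights[1] * mse([i * m for i, m in zip(image, mask)],
                               [t * m for t, m in zip(target, mask)])
    l2_cor = weights[2] * mse([i * t for i, t in zip(image, target)], [t * t for t in target])
    return l2 + l2_mask + l2_cor


print(quantize_depth([0.0, 0.2, 0.5, 0.9, 1.0], 4))  # [0, 1, 2, 3, 3]
```

The masked term up-weights errors inside the pixels assigned to a plane, which is where that plane is expected to be in focus.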

__call__(image, target, plane_id=None)

Calculates the multiplane loss against a given target.

Parameters:

  • image
            Image to compare with a target [3 x m x n].
    
  • target
            Target image for comparison [3 x m x n].
    
  • plane_id
            Index of the plane under test. If None, the masks of all planes are used.
    

Returns:

  • loss ( tensor ) –

    Computed loss.


__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, weights=[1.0, 2.1, 0.6, 0.0], multiplier=1.0, scheme='defocus', reduction='mean', device=torch.device('cpu'))

Parameters:

  • target_image
                Color target image [3 x m x n].
    
  • target_depth
                Monochrome target depth, same resolution as target_image.
    
  • target_blur_size
                Maximum target blur size.
    
  • blur_ratio
                Blur ratio, a value between zero and one.
    
  • number_of_planes
                Number of planes.
    
  • weights
                Weights of the loss function.
    
  • multiplier
                Multiplier applied to the targets.
    
  • scheme
                The type of the loss, `naive` without defocus or `defocus` with defocus.
    
  • reduction
                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    
  • device
                Device to be used (e.g., cuda, cpu, opencl).
    

add_defocus_blur()

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
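The blur strength between a target plane i and a content plane j is int(abs(i - j) * blur_ratio), and each kernel is normalized to sum to one. Note the integer truncation: with the default blur_ratio = 0.25 and number_of_planes = 4, every sigma is zero, so whether any blur is applied then depends on how generate_2d_gaussian handles a zero sigma, which is not shown here. The sketch below reproduces the sigma schedule and a normalized 1D Gaussian in plain Python; the helper names are hypothetical, and treating sigma = 0 as a delta is an assumption of the sketch.

```python
import math


def defocus_sigmas(number_of_planes, blur_ratio):
    """Sigma used between target plane i and content plane j, as in add_defocus_blur:
    int(abs(i - j) * blur_ratio), forced to zero when i == j."""
    return [
        [0 if i == j else int(abs(i - j) * blur_ratio) for j in range(number_of_planes)]
        for i in range(number_of_planes)
    ]


def normalized_gaussian_1d(length, sigma):
    """Gaussian kernel normalized to sum to one (like kernel / torch.sum(kernel)).
    Treating sigma == 0 as a delta is an assumption of this sketch."""
    center = (length - 1) / 2
    if sigma == 0:
        return [1.0 if i == center else 0.0 for i in range(length)]
    values = [math.exp(-((i - center) ** 2) / (2 * sigma ** 2)) for i in range(length)]
    total = sum(values)
    return [v / total for v in values]


print(defocus_sigmas(4, 0.25))  # all zeros: 0.25, 0.5 and 0.75 truncate to 0
print(defocus_sigmas(4, 1.0)[0])  # [0, 1, 2, 3]
```

Because target_blur_size is forced to be odd in __init__, the kernel always has a well-defined center pixel.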


get_targets()

Returns:

  • targets ( tensor ) –

    Returns a copy of the targets.

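  • focus_target ( tensor ) –

    Returns a copy of the in-focus target image (planes combined, without defocus blur).

  • target_depth ( tensor ) –

    Returns a copy of the normalized quantized depth map.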


set_targets()

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):
    """
    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
    """
    self.target_depth = self.target_depth * (self.number_of_planes - 1)
    self.target_depth = torch.round(self.target_depth, decimals = 0)
    self.targets      = torch.zeros(
                                    self.number_of_planes,
                                    self.target_image.shape[0],
                                    self.target_image.shape[1],
                                    self.target_image.shape[2],
                                    requires_grad = False,
                                    device = self.device
                                   )
    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
    self.masks        = torch.zeros_like(self.targets)
    for i in range(self.number_of_planes):
        for ch in range(self.target_image.shape[0]):
            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
            new_target = self.target_image[ch] * mask
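            # Accumulate per channel; adding an [H, W] slice to the full [C, H, W] tensor would broadcast it into every channel.
            self.focus_target[ch] = self.focus_target[ch] + new_target.squeeze(0).squeeze(0).detach().clone()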
            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
            self.masks[i, ch] = mask.detach().clone() 
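The slicing performed by set_targets() (and the normalization applied in get_targets()) can be reproduced in isolation. The following is a minimal, self-contained sketch in plain PyTorch, assuming a target image of shape [C, H, W] and a normalized depth map of shape [H, W] with values in [0, 1]; the helper name slice_depth is hypothetical and not part of odak.

```python
import torch


def slice_depth(target_image, depth, number_of_planes):
    """
    Hypothetical helper mirroring set_targets(): quantize a normalized depth map
    into number_of_planes planes and split the image into per-plane targets.
    """
    # Scale [0, 1] depth to plane indices 0..N-1 and round to the nearest plane.
    plane_index = torch.round(depth * (number_of_planes - 1))
    channels, height, width = target_image.shape
    targets = torch.zeros(number_of_planes, channels, height, width)
    masks = torch.zeros_like(targets)
    for i in range(number_of_planes):
        # Binary mask selecting the pixels assigned to plane i.
        mask = (plane_index == i).to(target_image.dtype)
        targets[i] = target_image * mask  # the [H, W] mask broadcasts over channels
        masks[i] = mask
    # get_targets() divides the quantized depth by (N - 1), guarding against N = 1.
    normalized_depth = plane_index / max(number_of_planes - 1, 1)
    return targets, masks, normalized_depth


torch.manual_seed(0)
image = torch.rand(3, 8, 8)
depth = torch.rand(8, 8)
targets, masks, normalized_depth = slice_depth(image, depth, number_of_planes = 4)
print(targets.shape)  # torch.Size([4, 3, 8, 8])
```

Because the masks partition the pixels, summing the per-plane targets over the plane axis recovers the input image.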

phase_gradient

Bases: Module

The class 'phase_gradient' provides a regularization function that measures the variation (gradient or Laplacian) of the phase of a complex amplitude.

This implements a convolution of the phase with a kernel.

The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.
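The computation behind this regularizer can be sketched without odak. The minimal example below assumes the default 3 by 3 Laplacian kernel (divided by 8) and the default mean-reduced MSE loss against zeros; the function name phase_laplacian_loss is hypothetical. Because conv2d zero-pads the borders, a constant non-zero phase still produces a response along the image edges, so only the interior is insensitive to a global phase offset.

```python
import math

import torch
import torch.nn.functional as F

# 3x3 Laplacian kernel normalized by 8, shaped [out_channels, in_channels, kH, kW].
laplacian = torch.tensor([[[[-1., -1., -1.],
                            [-1.,  8., -1.],
                            [-1., -1., -1.]]]]) / 8.


def phase_laplacian_loss(phase, kernel = laplacian):
    """Mean squared Laplacian response of a phase map (sketch of phase_gradient.forward)."""
    if phase.dim() == 2:
        # Promote [H, W] to [1, 1, H, W] as expected by conv2d.
        phase = phase.reshape(1, 1, *phase.shape)
    # Zero padding keeps the output the same size for odd kernel sizes.
    edges = F.conv2d(phase, kernel, padding = kernel.shape[-1] // 2)
    # Penalize any deviation of the filtered phase from zero.
    return F.mse_loss(edges, torch.zeros_like(edges))


torch.manual_seed(0)
flat = torch.zeros(64, 64)
noisy = torch.rand(64, 64) * 2 * math.pi
print(phase_laplacian_loss(flat).item())       # 0.0
print(phase_laplacian_loss(noisy).item() > 0)  # True
```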

Source code in odak/learn/wave/loss.py
class phase_gradient(nn.Module):

    """
    The class 'phase_gradient' provides a regularization function that measures the variation (gradient or Laplacian) of the phase of a complex amplitude.

    This implements a convolution of the phase with a kernel.

    The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.
    """


    def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
        """
        Parameters
        ----------
        kernel                  : torch.tensor
                                    Convolution filter kernel, 3 by 3 Laplacian kernel by default.
        loss                    : torch.nn.Module
                                    Loss function, L2 (MSE) loss by default.
        device                  : torch.device
                                    Device on which the kernel is stored, CPU by default.
        """
        super(phase_gradient, self).__init__()
        self.device = device
        self.loss = loss
        if kernel is None:
            self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
        else:
            if len(kernel.shape) == 4:
                self.kernel = kernel
            else:
                self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
        self.kernel = Variable(self.kernel.to(self.device))


    def forward(self, phase):
        """
        Calculates the phase gradient Loss.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(phase.shape) == 2:
            phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
        edge_detect = self.functional_conv2d(phase)
        loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
        return loss_value


    def functional_conv2d(self, phase):
        """
        Calculates the gradient of the phase.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        edge_detect              : torch.tensor
                                    The computed phase gradient.
        """
        edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
        return edge_detect

__init__(kernel=None, loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel
                        Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    
  • loss
                        Loss function, L2 (MSE) loss by default.
    
  • device
                        Device on which the kernel is stored, CPU by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
    """
    Parameters
    ----------
    kernel                  : torch.tensor
                                Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    loss                    : torch.nn.Module
                                Loss function, L2 (MSE) loss by default.
    device                  : torch.device
                                Device on which the kernel is stored, CPU by default.
    """
    super(phase_gradient, self).__init__()
    self.device = device
    self.loss = loss
    if kernel is None:
        self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
    else:
        if len(kernel.shape) == 4:
            self.kernel = kernel
        else:
            self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
    self.kernel = Variable(self.kernel.to(self.device))
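__init__ accepts either a 4D kernel or a 2D [kH, kW] kernel, which it reshapes to [1, 1, kH, kW]. The sketch below shows that reshape with a horizontal Sobel kernel (an example choice, not an odak default) and checks its response on a linear phase ramp using plain PyTorch.

```python
import torch
import torch.nn.functional as F

# A 2D horizontal Sobel kernel; any 2D edge-detection kernel could be used instead.
sobel_x = torch.tensor([[-1., 0., 1.],
                        [-2., 0., 2.],
                        [-1., 0., 1.]])
# Same reshape that phase_gradient.__init__ applies to 2D kernels.
kernel = sobel_x.reshape((1, 1, sobel_x.shape[0], sobel_x.shape[1]))

# Phase that increases by one radian per column: ramp[..., y, x] = x.
ramp = torch.arange(16.).repeat(16, 1).reshape(1, 1, 16, 16)
response = F.conv2d(ramp, kernel, padding = kernel.shape[-1] // 2)
# Away from the zero-padded borders: 4 * (x + 1) - 4 * (x - 1) = 8 everywhere.
print(response[0, 0, 1:-1, 1:-1].unique())  # tensor([8.])
```

Passing sobel_x as the kernel argument of phase_gradient would go through the same reshape, so the resulting loss penalizes horizontal phase slopes instead of the Laplacian.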

forward(phase)

Calculates the phase gradient Loss.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, phase):
    """
    Calculates the phase gradient Loss.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(phase.shape) == 2:
        phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
    edge_detect = self.functional_conv2d(phase)
    loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
    return loss_value

functional_conv2d(phase)

Calculates the gradient of the phase.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • edge_detect ( tensor ) –

    The computed phase gradient.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, phase):
    """
    Calculates the gradient of the phase.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    edge_detect              : torch.tensor
                                The computed phase gradient.
    """
    edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
    return edge_detect

speckle_contrast

Bases: Module

The class 'speckle_contrast' provides a regularization function that measures the speckle contrast of the intensity of a complex amplitude as C = sigma / mean, where C is the speckle contrast, and mean and sigma are the mean and standard deviation of the intensity.

We refer to the following paper:

Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. DOI: 10.1038/s41598-020-75947-0.
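The local contrast C = sigma / mean can be sketched in plain PyTorch. The example below is an approximation of functional_conv2d, not odak's code: it assumes avg_pool2d with a k by k window (equivalent to convolving with a ones(k, k) / k**2 kernel without padding) and adds a clamp before the square root, since E[I^2] - E[I]^2 can dip slightly below zero through rounding. The helper name local_speckle_contrast is hypothetical.

```python
import torch
import torch.nn.functional as F


def local_speckle_contrast(intensity, kernel_size = 11, stride = 1):
    """
    Sketch of speckle_contrast.functional_conv2d: C = sigma / mean over a sliding window.
    """
    if intensity.dim() == 2:
        intensity = intensity.reshape(1, 1, *intensity.shape)
    # Local first and second moments over each kernel_size x kernel_size window.
    mean = F.avg_pool2d(intensity, kernel_size, stride = stride)
    mean_of_square = F.avg_pool2d(intensity ** 2, kernel_size, stride = stride)
    # Clamp guards against tiny negative variances caused by floating-point rounding.
    sigma = torch.sqrt(torch.clamp(mean_of_square - mean ** 2, min = 0.))
    return sigma / mean


torch.manual_seed(0)
uniform = torch.ones(64, 64)
# Fully developed speckle has exponentially distributed intensity, for which C = 1.
speckle = torch.empty(256, 256).exponential_()
print(local_speckle_contrast(uniform).max().item())   # close to 0 (no speckle)
print(local_speckle_contrast(speckle).mean().item())  # close to 1
```

As in the class, no padding is applied, so an H x W input with an 11 x 11 window yields an (H - 10) x (W - 10) contrast map at unit stride.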

Source code in odak/learn/wave/loss.py
class speckle_contrast(nn.Module):

    """
    The class 'speckle_contrast' provides a regularization function that measures the speckle contrast of the intensity of a complex amplitude as C = sigma / mean, where C is the speckle contrast, and mean and sigma are the mean and standard deviation of the intensity.

    We refer to the following paper:

    Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. DOI: 10.1038/s41598-020-75947-0.
    """


    def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
        """
        Parameters
        ----------
        kernel_size             : int
                                    Convolution filter kernel size, 11 by 11 average kernel by default.
        step_size               : tuple
                                    Convolution stride in height and width direction.
        loss                    : torch.nn.Module
                                    Loss function, L2 (MSE) loss by default.
        device                  : torch.device
                                    Device on which the kernel is stored, CPU by default.
        """
        super(speckle_contrast, self).__init__()
        self.device = device
        self.loss = loss
        self.step_size = step_size
        self.kernel_size = kernel_size
        self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
        self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))


    def forward(self, intensity):
        """
        Calculates the speckle contrast Loss.

        Parameters
        ----------
        intensity               : torch.tensor
                                    Intensity of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(intensity.shape) == 2:
            intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
        Speckle_C = self.functional_conv2d(intensity)
        loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
        return loss_value


    def functional_conv2d(self, intensity):
        """
        Calculates the speckle contrast of the intensity.

        Parameters
        ----------
        intensity                : torch.tensor
                                    Intensity of the complex field.

        Returns
        -------

        Speckle_C               : torch.tensor
                                    The computed speckle contrast.
        """
        mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
        var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
        Speckle_C = var / mean
        return Speckle_C

__init__(kernel_size=11, step_size=(1, 1), loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel_size
                        Convolution filter kernel size, 11 by 11 average kernel by default.
    
  • step_size
                        Convolution stride in height and width direction.
    
  • loss
                        Loss function, L2 (MSE) loss by default.
    
  • device
                        Device on which the kernel is stored, CPU by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
    """
    Parameters
    ----------
    kernel_size             : int
                                Convolution filter kernel size, 11 by 11 average kernel by default.
    step_size               : tuple
                                Convolution stride in height and width direction.
    loss                    : torch.nn.Module
                                Loss function, L2 (MSE) loss by default.
    device                  : torch.device
                                Device on which the kernel is stored, CPU by default.
    """
    super(speckle_contrast, self).__init__()
    self.device = device
    self.loss = loss
    self.step_size = step_size
    self.kernel_size = kernel_size
    self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
    self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))

forward(intensity)

Calculates the speckle contrast Loss.

Parameters:

  • intensity
                        Intensity of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, intensity):
    """
    Calculates the speckle contrast Loss.

    Parameters
    ----------
    intensity               : torch.tensor
                                Intensity of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(intensity.shape) == 2:
        intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
    Speckle_C = self.functional_conv2d(intensity)
    loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
    return loss_value

functional_conv2d(intensity)

Calculates the speckle contrast of the intensity.

Parameters:

  • intensity
                        Intensity of the complex field.
    

Returns:

  • Speckle_C ( tensor ) –

    The computed speckle contrast.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, intensity):
    """
    Calculates the speckle contrast of the intensity.

    Parameters
    ----------
    intensity                : torch.tensor
                                Intensity of the complex field.

    Returns
    -------

    Speckle_C               : torch.tensor
                                The computed speckle contrast.
    """
    mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
    var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
    Speckle_C = var / mean
    return Speckle_C

holobeam_multiholo

Bases: Module

The learned holography model used in the paper, Akşit, Kaan, and Yuta Itoh. "HoloBeam: Paper-Thin Near-Eye Displays." In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.

Parameters:

  • n_input
                Number of channels in the input.
    
  • n_hidden
                Number of channels in the hidden layers.
    
  • n_output
                Number of channels in the output layer.
    
  • device
                Default device is CPU.
    
  • reduction
                Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
    
Source code in odak/learn/wave/models.py
class holobeam_multiholo(torch.nn.Module):
    """
    The learned holography model used in the paper, Akşit, Kaan, and Yuta Itoh. "HoloBeam: Paper-Thin Near-Eye Displays." In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.


    Parameters
    ----------
    n_input           : int
                        Number of channels in the input.
    n_hidden          : int
                        Number of channels in the hidden layers.
    n_output          : int
                        Number of channels in the output layer.
    device            : torch.device
                        Default device is CPU.
    reduction         : str
                        Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
    """
    def __init__(
                 self,
                 n_input = 1,
                 n_hidden = 16,
                 n_output = 2,
                 device = torch.device('cpu'),
                 reduction = 'sum'
                ):
        super(holobeam_multiholo, self).__init__()
        torch.random.seed()
        self.device = device
        self.reduction = reduction
        self.l2 = torch.nn.MSELoss(reduction = self.reduction)
        self.l1 = torch.nn.L1Loss(reduction = self.reduction)
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_output = n_output
        self.network = unet(
                            dimensions = self.n_hidden,
                            input_channels = self.n_input,
                            output_channels = self.n_output
                           ).to(self.device)


    def forward(self, x, test = False):
        """
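        Forward model: predicts two phase maps with the U-Net and interleaves them in a checkerboard pattern into a single phase-only hologram.
        """
        if test:
            # torch.no_grad() only disables gradient tracking when used as a context manager.
            with torch.no_grad():
                y = self.network.forward(x)
        else:
            y = self.network.forward(x)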
        phase_low = y[:, 0].unsqueeze(1)
        phase_high = y[:, 1].unsqueeze(1)
        phase_only = torch.zeros_like(phase_low)
        phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]
        phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]
        phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]
        phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]
        return phase_only


    def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):
        """
        Internal function that computes the training loss as a weighted sum of the L2 (MSE) and L1 losses.
        """
        loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
        return loss


    def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):
        """
        Function to train the weights of the U-Net.

        Parameters
        ----------
        dataloader       : torch.utils.data.DataLoader
                           Data loader.
        number_of_epochs : int
                           Number of epochs.
        learning_rate    : float
                           Learning rate of the optimizer.
        directory        : str
                           Output directory.
        save_at_every    : int
                           Save the model at every given epoch count.
        """
        t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        for i in t_epoch:
            epoch_loss = 0.
            t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)
            for j, data in enumerate(t_data):
                self.optimizer.zero_grad()
                images, holograms = data
                estimates = self.forward(images)
                loss = self.evaluate(estimates, holograms)
                loss.backward(retain_graph=True)
                self.optimizer.step()
                description = 'Loss:{:.4f}'.format(loss.item())
                t_data.set_description(description)
                epoch_loss += float(loss.item()) / dataloader.__len__()
            description = 'Epoch Loss:{:.4f}'.format(epoch_loss)
            t_epoch.set_description(description)
            if i % save_at_every == 0:
                self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))
        self.save_weights(filename='{}/weights.pt'.format(directory))
        print(description)


    def save_weights(self, filename = './weights.pt'):
        """
        Function to save the current weights of the U-Net to a file.
        Parameters
        ----------
        filename        : str
                          Filename.
        """
        torch.save(self.network.state_dict(), os.path.expanduser(filename))


    def load_weights(self, filename = './weights.pt'):
        """
        Function to load weights for this U-Net from a file.
        Parameters
        ----------
        filename        : str
                          Filename.
        """
        self.network.load_state_dict(torch.load(os.path.expanduser(filename)))
        self.network.eval()

evaluate(input_data, ground_truth, weights=[1.0, 0.1])

Internal function that computes the training loss as a weighted sum of the L2 (MSE) and L1 losses.

Source code in odak/learn/wave/models.py
def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):
    """
    Internal function that computes the training loss as a weighted sum of the L2 (MSE) and L1 losses.
    """
    loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
    return loss
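The weighted loss can be checked by hand on a tiny example. The sketch below re-implements the same formula with standalone PyTorch losses using the class's default 'sum' reduction; the function name weighted_loss is hypothetical, and a tuple default is used for the weights to avoid a mutable default argument (the original list default is harmless here because it is never modified).

```python
import torch

# Sum reduction matches the default reduction of holobeam_multiholo.
l2 = torch.nn.MSELoss(reduction = 'sum')
l1 = torch.nn.L1Loss(reduction = 'sum')


def weighted_loss(estimate, ground_truth, weights = (1., 0.1)):
    """Weighted L2 + L1 loss, following holobeam_multiholo.evaluate."""
    return weights[0] * l2(estimate, ground_truth) + weights[1] * l1(estimate, ground_truth)


estimate = torch.tensor([1., 2.])
ground_truth = torch.zeros(2)
# L2 (sum) = 1 + 4 = 5 and L1 (sum) = 1 + 2 = 3, so the loss is 5 + 0.1 * 3 = 5.3.
print(weighted_loss(estimate, ground_truth).item())
```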

fit(dataloader, number_of_epochs=100, learning_rate=1e-05, directory='./output', save_at_every=100)

Function to train the weights of the U-Net.

Parameters:

  • dataloader
               Data loader.
    
  • number_of_epochs (int, default: 100 ) –
               Number of epochs.
    
  • learning_rate
               Learning rate of the optimizer.
    
  • directory
               Output directory.
    
  • save_at_every
               Save the model at every given epoch count.
    
Source code in odak/learn/wave/models.py
def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):
    """
    Function to train the weights of the U-Net.

    Parameters
    ----------
    dataloader       : torch.utils.data.DataLoader
                       Data loader.
    number_of_epochs : int
                       Number of epochs.
    learning_rate    : float
                       Learning rate of the optimizer.
    directory        : str
                       Output directory.
    save_at_every    : int
                       Save the model at every given epoch count.
    """
    t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)
    self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
    for i in t_epoch:
        epoch_loss = 0.
        t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)
        for j, data in enumerate(t_data):
            self.optimizer.zero_grad()
            images, holograms = data
            estimates = self.forward(images)
            loss = self.evaluate(estimates, holograms)
            loss.backward(retain_graph=True)
            self.optimizer.step()
            description = 'Loss:{:.4f}'.format(loss.item())
            t_data.set_description(description)
            epoch_loss += float(loss.item()) / dataloader.__len__()
        description = 'Epoch Loss:{:.4f}'.format(epoch_loss)
        t_epoch.set_description(description)
        if i % save_at_every == 0:
            self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))
    self.save_weights(filename='{}/weights.pt'.format(directory))
    print(description)
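The loop in fit() follows the standard PyTorch training pattern. The sketch below reproduces it under stated assumptions: a toy single-layer Conv2d stands in for odak's U-Net, and random tensors stand in for image and hologram pairs. It averages per-batch losses with len(dataloader), and calls loss.backward() without retain_graph, since a fresh graph is built on every step.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Toy stand-in for the U-Net (assumption for illustration): one input and one output channel.
network = torch.nn.Conv2d(1, 1, kernel_size = 3, padding = 1)
images = torch.rand(8, 1, 16, 16)
holograms = torch.rand(8, 1, 16, 16)
dataloader = DataLoader(TensorDataset(images, holograms), batch_size = 4, shuffle = True)

l2 = torch.nn.MSELoss(reduction = 'sum')
l1 = torch.nn.L1Loss(reduction = 'sum')
optimizer = torch.optim.Adam(network.parameters(), lr = 1e-2)
epoch_losses = []
for epoch in range(20):
    epoch_loss = 0.
    for batch_images, batch_holograms in dataloader:
        optimizer.zero_grad()
        estimates = network(batch_images)
        # Same weighting as evaluate(): 1.0 * L2 + 0.1 * L1.
        loss = l2(estimates, batch_holograms) + 0.1 * l1(estimates, batch_holograms)
        loss.backward()
        optimizer.step()
        # Average the per-batch losses over the number of batches, as fit() does.
        epoch_loss += loss.item() / len(dataloader)
    epoch_losses.append(epoch_loss)
print(epoch_losses[0], epoch_losses[-1])
```

Note that fit() checks i % save_at_every == 0, so it also writes a checkpoint at epoch 0 before saving the final weights.pt.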

forward(x, test=False)

Forward model: predicts two phase maps and interleaves them in a checkerboard pattern into a single phase-only hologram.

Source code in odak/learn/wave/models.py
def forward(self, x, test = False):
    """
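    Forward model: predicts two phase maps with the U-Net and interleaves them in a checkerboard pattern into a single phase-only hologram.
    """
    if test:
        # torch.no_grad() only disables gradient tracking when used as a context manager.
        with torch.no_grad():
            y = self.network.forward(x)
    else:
        y = self.network.forward(x)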
    phase_low = y[:, 0].unsqueeze(1)
    phase_high = y[:, 1].unsqueeze(1)
    phase_only = torch.zeros_like(phase_low)
    phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]
    phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]
    phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]
    phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]
    return phase_only
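The checkerboard interleaving at the end of forward() can be isolated as below. This is a sketch: checkerboard_merge is a hypothetical name, and the two input maps are synthetic. It also shows that torch.no_grad() disables gradient tracking when used as a context manager (a bare torch.no_grad() statement has no effect).

```python
import torch


def checkerboard_merge(phase_low, phase_high):
    """
    Interleave two [B, 1, H, W] phase maps as holobeam_multiholo.forward does:
    phase_low fills the (even, even) and (odd, odd) pixels, phase_high the rest.
    """
    phase_only = torch.zeros_like(phase_low)
    phase_only[:, :, 0::2, 0::2] = phase_low[:, :, 0::2, 0::2]
    phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]
    phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]
    phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]
    return phase_only


low = torch.zeros(1, 1, 4, 4)
high = torch.ones(1, 1, 4, 4)
print(checkerboard_merge(low, high)[0, 0])
# tensor([[0., 1., 0., 1.],
#         [1., 0., 1., 0.],
#         [0., 1., 0., 1.],
#         [1., 0., 1., 0.]])

# Inference without gradient tracking: the context manager form is what takes effect.
weights = torch.rand(1, 1, 4, 4, requires_grad = True)
with torch.no_grad():
    merged = checkerboard_merge(weights, weights * 2)
print(merged.requires_grad)  # False
```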

load_weights(filename='./weights.pt')

Function to load weights for this U-Net from a file.

Parameters:

  • filename
              Filename.
    
Source code in odak/learn/wave/models.py
def load_weights(self, filename = './weights.pt'):
    """
    Function to load weights for this U-Net from a file.
    Parameters
    ----------
    filename        : str
                      Filename.
    """
    self.network.load_state_dict(torch.load(os.path.expanduser(filename)))
    self.network.eval()

save_weights(filename='./weights.pt')

Function to save the current weights of the U-Net to a file.

Parameters:

  • filename
              Filename.
    
Source code in odak/learn/wave/models.py
def save_weights(self, filename = './weights.pt'):
    """
    Function to save the current weights of the U-Net to a file.
    Parameters
    ----------
    filename        : str
                      Filename.
    """
    torch.save(self.network.state_dict(), os.path.expanduser(filename))
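save_weights() and load_weights() store and restore only the network's state_dict, not the module object, so the network must be constructed with the same architecture before loading. The round trip below is a sketch that uses a small Conv2d as a stand-in for the U-Net (an assumption) and a temporary directory instead of './weights.pt'; the class additionally passes filenames through os.path.expanduser so that paths starting with '~' work.

```python
import os
import tempfile

import torch

# Small stand-in networks with identical architecture.
network = torch.nn.Conv2d(1, 2, kernel_size = 3, padding = 1)
restored = torch.nn.Conv2d(1, 2, kernel_size = 3, padding = 1)

with tempfile.TemporaryDirectory() as directory:
    filename = os.path.join(directory, 'weights.pt')
    # As in save_weights(): store only the parameters.
    torch.save(network.state_dict(), filename)
    # As in load_weights(): load the parameters, then switch to evaluation mode.
    restored.load_state_dict(torch.load(filename))
    restored.eval()

same = all(torch.equal(a, b) for a, b in zip(network.state_dict().values(), restored.state_dict().values()))
print(same)  # True
```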

multi_color_hologram_optimizer

A class for optimizing single or multi-color holograms. For more details, see Kavaklı et al., SIGGRAPH Asia 2023, Multi-color Holograms Improve Brightness in Holographic Displays.
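In the source below, init_channel_power() sets up a [number_of_frames, number_of_channels] matrix of laser powers: an identity matrix for the conventional scheme (each frame lights a single color) and a learnable all-ones matrix for the multi-color scheme. The sketch below illustrates how such a matrix could weight per-frame, per-color intensities into a perceived image. The time-averaged weighted sum and the helper perceived_image are simplifying assumptions for illustration, not odak's propagator.

```python
import torch

number_of_frames, number_of_channels = 3, 3
# Conventional: frame i illuminates only color channel i.
conventional = torch.eye(number_of_frames, number_of_channels)
# Multi-color: every frame illuminates every channel; the powers are optimized, so they require gradients.
multi_color = torch.ones(number_of_frames, number_of_channels, requires_grad = True)

# Synthetic per-frame, per-channel intensities at the image plane: [frames, channels, H, W].
torch.manual_seed(0)
intensities = torch.rand(number_of_frames, number_of_channels, 4, 4)


def perceived_image(channel_power, intensities):
    """Assumed model: sum over frames of power[f, c] * intensity[f, c] (time integration)."""
    return torch.einsum('fc,fchw->chw', channel_power, intensities)


conventional_image = perceived_image(conventional, intensities)
multi_color_image = perceived_image(multi_color, intensities)
print(conventional_image.shape)  # torch.Size([3, 4, 4])
```

Under this assumed model, the conventional scheme passes only one frame's intensity to each channel, while the multi-color scheme lets every frame contribute to every channel, with the contributions controlled by optimizable powers.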

Source code in odak/learn/wave/optimizers.py
class multi_color_hologram_optimizer():
    """
    A class for optimizing single or multi color holograms.
    For more details, see Kavaklı et al., SIGGRAPH Asia 2023, Multi-color Holograms Improve Brightness in Holographic Displays.
    """
    def __init__(self,
                 wavelengths,
                 resolution,
                 targets,
                 propagator,
                 number_of_frames = 3,
                 number_of_depth_layers = 1,
                 learning_rate = 2e-2,
                 learning_rate_floor = 5e-3,
                 double_phase = True,
                 scale_factor = 1,
                 method = 'multi-color',
                 channel_power_filename = '',
                 device = None,
                 loss_function = None,
                 peak_amplitude = 1.0,
                 optimize_peak_amplitude = False,
                 img_loss_thres = 2e-3,
                 reduction = 'sum'
                ):
        self.device = device
        if isinstance(self.device, type(None)):
            self.device = torch.device("cpu")
        torch.cuda.empty_cache()
        torch.random.seed()
        self.wavelengths = wavelengths
        self.resolution = resolution
        self.targets = targets
        self.scale_factor = scale_factor
        self.propagator = propagator
        self.learning_rate = learning_rate
        self.learning_rate_floor = learning_rate_floor
        self.number_of_channels = len(self.wavelengths)
        self.number_of_frames = number_of_frames
        self.number_of_depth_layers = number_of_depth_layers
        self.double_phase = double_phase
        self.channel_power_filename = channel_power_filename
        self.method = method
        self.upsample = torch.nn.Upsample(scale_factor = self.scale_factor, mode = 'nearest')
        if self.method != 'conventional' and self.method != 'multi-color':
            raise ValueError('Unknown optimization method. Options are conventional or multi-color.')
        self.peak_amplitude = peak_amplitude
        self.optimize_peak_amplitude = optimize_peak_amplitude
        if self.optimize_peak_amplitude:
            self.init_peak_amplitude_scale()
        self.img_loss_thres = img_loss_thres
        self.kernels = []
        self.init_phase()
        self.init_channel_power()
        self.init_loss_function(loss_function, reduction = reduction)
        self.init_amplitude()
        self.init_phase_scale()


    def init_peak_amplitude_scale(self):
        """
        Internal function to turn the peak amplitude into an optimizable tensor.
        """
        self.peak_amplitude = torch.tensor(
                                           self.peak_amplitude,
                                           requires_grad = True,
                                           device=self.device
                                          )


    def init_phase_scale(self):
        """
        Internal function to set the phase scale.
        """
        if self.method == 'conventional':
            self.phase_scale = torch.tensor(
                                            [
                                             1.,
                                             1.,
                                             1.
                                            ],
                                            requires_grad = False,
                                            device = self.device
                                           )
        if self.method == 'multi-color':
            self.phase_scale = torch.tensor(
                                            [
                                             1.,
                                             1.,
                                             1.
                                            ],
                                            requires_grad = False,
                                            device = self.device
                                           )


    def init_amplitude(self):
        """
        Internal function to set the amplitude of the illumination source.
        """
        self.amplitude = torch.ones(
                                    self.resolution[0],
                                    self.resolution[1],
                                    requires_grad = False,
                                    device = self.device
                                   )


    def init_phase(self):
        """
        Internal function to set the starting phase of the phase-only hologram.
        """
        self.phase = torch.zeros(
                                 self.number_of_frames,
                                 self.resolution[0],
                                 self.resolution[1],
                                 device = self.device,
                                 requires_grad = True
                                )
        self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)


    def init_channel_power(self):
        """
        Internal function to set the laser power of each color channel in each frame.
        """
        if self.method == 'conventional':
            logging.warning('Scheme: Conventional')
            self.channel_power = torch.eye(
                                           self.number_of_frames,
                                           self.number_of_channels,
                                           device = self.device,
                                           requires_grad = False
                                          )

        elif self.method == 'multi-color':
            logging.warning('Scheme: Multi-color')
            self.channel_power = torch.ones(
                                            self.number_of_frames,
                                            self.number_of_channels,
                                            device = self.device,
                                            requires_grad = True
                                           )
        if self.channel_power_filename != '':
            self.channel_power = torch_load(self.channel_power_filename).to(self.device)
            self.channel_power.requires_grad = False
            self.channel_power[self.channel_power < 0.] = 0.
            self.channel_power[self.channel_power > 1.] = 1.
            if self.method == 'multi-color':
                self.channel_power.requires_grad = True
            if self.method == 'conventional':
                self.channel_power = torch.abs(torch.cos(self.channel_power))
            logging.warning('Channel powers:')
            logging.warning(self.channel_power)
            logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))
        self.propagator.set_laser_powers(self.channel_power)



    def init_optimizer(self):
        """
        Internal function to set the optimizer.
        """
        optimization_variables = [self.phase, self.offset]
        if self.optimize_peak_amplitude:
            optimization_variables.append(self.peak_amplitude)
        if self.method == 'multi-color':
            optimization_variables.append(self.propagator.channel_power)
        self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)


    def init_loss_function(self, loss_function, reduction = 'sum'):
        """
        Internal function to set the loss function.
        """
        self.l2_loss = torch.nn.MSELoss(reduction = reduction)
        self.loss_type = 'custom'
        self.loss_function = loss_function
        if isinstance(self.loss_function, type(None)):
            self.loss_type = 'conventional'
            self.loss_function = torch.nn.MSELoss(reduction = reduction)



    def evaluate(self, input_image, target_image, plane_id = 0):
        """
        Internal function to evaluate the loss.
        """
        if self.loss_type == 'conventional':
            loss = self.loss_function(input_image, target_image)
        elif self.loss_type == 'custom':
            loss = 0
            for i in range(len(self.wavelengths)):
                loss += self.loss_function(
                                           input_image[i],
                                           target_image[i],
                                           plane_id = plane_id
                                          )
        return loss


    def double_phase_constrain(self, phase, phase_offset):
        """
        Internal function to constrain a given phase similarly to double phase encoding.

        Parameters
        ----------
        phase                      : torch.tensor
                                     Input phase values to be constrained.
        phase_offset               : torch.tensor
                                     Input phase offset value.

        Returns
        -------
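        phase_only                 : torch.tensor
                                     Constrained output phase.
        loss                       : torch.tensor
                                     Regularization loss (multi-scale total variation and standard deviation of the low and high phase components).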
        """
        phase_zero_mean = phase - torch.mean(phase)
        phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * np.pi)
        phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * np.pi)
        loss = multi_scale_total_variation_loss(phase_low, levels = 6)
        loss += multi_scale_total_variation_loss(phase_high, levels = 6)
        loss += torch.std(phase_low)
        loss += torch.std(phase_high)
        phase_only = torch.zeros_like(phase)
        phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
        phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
        phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
        phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
        return phase_only, loss


    def direct_phase_constrain(self, phase, phase_offset):
        """
        Internal function to constrain a given phase.

        Parameters
        ----------
        phase                      : torch.tensor
                                     Input phase values to be constrained.
        phase_offset               : torch.tensor
                                     Input phase offset value.

        Returns
        -------
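        phase_only                 : torch.tensor
                                     Constrained output phase.
        loss                       : torch.tensor
                                     Regularization loss (multi-scale total variation of the phase and of the phase offset).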
        """
        phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi)
        loss = multi_scale_total_variation_loss(phase, levels = 6)
        loss += multi_scale_total_variation_loss(phase_offset, levels = 6)
        return phase_only, loss


    def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
        """
        Function to optimize multiplane phase-only holograms using stochastic gradient descent.

        Parameters
        ----------
        number_of_iterations       : int
                                     Number of iterations.
        weights                    : list
                                     Weights used in the loss function.

        Returns
        -------
        hologram_phases            : torch.tensor
                                     Optimized hologram phases.
        """
        hologram_phases = torch.zeros(
                                      self.number_of_frames,
                                      self.resolution[0],
                                      self.resolution[1],
                                      device = self.device
                                     )
        t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)
        if self.optimize_peak_amplitude:
            peak_amp_cache = self.peak_amplitude.item()
        for step in t:
            for g in self.optimizer.param_groups:
                g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations
                if g['lr'] < self.learning_rate_floor:
                    g['lr'] = self.learning_rate_floor
                learning_rate = g['lr']
            total_loss = 0
            t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)
            for depth_id in t_depth:
                self.optimizer.zero_grad()
                depth_target = self.targets[depth_id]
                reconstruction_intensities = torch.zeros(
                                                         self.number_of_frames,
                                                         self.number_of_channels,
                                                         self.resolution[0] * self.scale_factor,
                                                         self.resolution[1] * self.scale_factor,
                                                         device = self.device
                                                        )
                loss_variation_hologram = 0
                laser_powers = self.propagator.get_laser_powers()
                for frame_id in range(self.number_of_frames):
                    if self.double_phase:
                        phase, loss_phase = self.double_phase_constrain(
                                                                        self.phase[frame_id],
                                                                        self.offset[frame_id]
                                                                       )
                    else:
                        phase, loss_phase = self.direct_phase_constrain(
                                                                        self.phase[frame_id],
                                                                        self.offset[frame_id]
                                                                       )
                    loss_variation_hologram += loss_phase
                    for channel_id in range(self.number_of_channels):
                        if self.scale_factor != 1:
                            phase_scaled = torch.zeros_like(self.amplitude)
                            phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
                            self.amplitude[1::self.scale_factor, 1::self.scale_factor] = 0.
                        else:
                            phase_scaled = phase
                        laser_power = laser_powers[frame_id][channel_id]
                        hologram = generate_complex_field(
                                                          laser_power * self.amplitude,
                                                          phase_scaled * self.phase_scale[channel_id]
                                                         )
                        reconstruction_field = self.propagator(hologram, channel_id, depth_id)
                        intensity = calculate_amplitude(reconstruction_field) ** 2
                        reconstruction_intensities[frame_id, channel_id] += intensity
                    hologram_phases[frame_id] = phase.detach().clone()
                loss_laser = self.l2_loss(
                                          torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,
                                          torch.sum(laser_powers, dim = 0)
                                         )
                loss_laser += self.l2_loss(
                                           torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),
                                           torch.sum(laser_powers).view(1,)
                                          )
                loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))
                reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)
                loss_image = self.evaluate(
                                           reconstruction_intensity,
                                           depth_target * self.peak_amplitude,
                                           plane_id = depth_id
                                          )
                loss = weights[0] * loss_image
                loss += weights[1] * loss_laser
                loss += weights[2] * loss_variation_hologram
                include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres
                if include_pa_loss_flag:
                    loss -= self.peak_amplitude * 1.
                if self.method == 'conventional':
                    loss.backward()
                else:
                    loss.backward(retain_graph = True)
                self.optimizer.step()
                if include_pa_loss_flag:
                    peak_amp_cache = self.peak_amplitude.item()
                else:
                    with torch.no_grad():
                        if self.optimize_peak_amplitude:
                            self.peak_amplitude.view([1])[0] = peak_amp_cache
                total_loss += loss.detach().item()
                loss_image = loss_image.detach()
                del loss_laser
                del loss_variation_hologram
                del loss
            description = "Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)
            t.set_description(description)
            del total_loss
            del loss_image
            del reconstruction_field
            del reconstruction_intensities
            del intensity
            del phase
            del hologram
        logging.warning(description)
        return hologram_phases.detach()


    def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8):
        """
        Function to optimize multiplane phase-only holograms.

        Parameters
        ----------
        number_of_iterations       : int
                                     Number of iterations.
        weights                    : list
                                     Loss weights.
        bits                       : int
                                     Quantizes the hologram using the given bits and reconstructs.

        Returns
        -------
        hologram_phases            : torch.tensor
                                     Phases of the optimized phase-only hologram.
        reconstruction_intensities : torch.tensor
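                                     Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
        laser_powers               : torch.tensor
                                     Laser powers of each frame and color channel, as returned by the propagator.
        channel_powers             : torch.tensor
                                     Channel powers stored in the propagator.
        peak_amplitude             : float
                                     Final peak amplitude.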
        """
        self.init_optimizer()
        hologram_phases = self.gradient_descent(
                                                number_of_iterations=number_of_iterations,
                                                weights=weights
                                               )
        hologram_phases = quantize(hologram_phases % (2 * np.pi), bits = bits, limits = [0., 2 * np.pi]) / 2 ** bits * 2 * np.pi
        with torch.no_grad():
            reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
        laser_powers = self.propagator.get_laser_powers()
        channel_powers = self.propagator.channel_power
        logging.warning("Final peak amplitude: {}".format(self.peak_amplitude))
        logging.warning('Laser powers: {}'.format(laser_powers))
        return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)

direct_phase_constrain(phase, phase_offset)

Internal function to constrain a given phase.

Parameters:

  • phase
                         Input phase values to be constrained.
    
  • phase_offset
                         Input phase offset value.
    

Returns:

  • phase_only ( tensor ) –

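    Constrained output phase.

  • loss ( tensor ) –

    Regularization loss (multi-scale total variation of the phase and of the phase offset).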

Source code in odak/learn/wave/optimizers.py
def direct_phase_constrain(self, phase, phase_offset):
    """
    Internal function to constrain a given phase.

    Parameters
    ----------
    phase                      : torch.tensor
                                 Input phase values to be constrained.
    phase_offset               : torch.tensor
                                 Input phase offset value.

    Returns
    -------
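    phase_only                 : torch.tensor
                                 Constrained output phase.
    loss                       : torch.tensor
                                 Regularization loss (multi-scale total variation of the phase and of the phase offset).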
    """
    phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi)
    loss = multi_scale_total_variation_loss(phase, levels = 6)
    loss += multi_scale_total_variation_loss(phase_offset, levels = 6)
    return phase_only, loss
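The regularization term in direct_phase_constrain comes from odak's multi_scale_total_variation_loss, whose internals are not shown on this page. As a rough illustration only, the NumPy sketch below mirrors the subtraction and NaN replacement steps and swaps in a hypothetical single-scale total_variation helper; that helper is an assumption for demonstration and is not odak's implementation.

```python
import numpy as np


def total_variation(image):
    # Hypothetical single-scale total variation: mean absolute difference
    # between neighbouring pixels along columns plus along rows.
    # A simplified stand-in, NOT odak's multi_scale_total_variation_loss.
    horizontal = np.abs(np.diff(image, axis=1)).mean()
    vertical = np.abs(np.diff(image, axis=0)).mean()
    return horizontal + vertical


def direct_phase_constrain_sketch(phase, phase_offset):
    # Subtract the offset and replace NaN values with 2 * pi, mirroring
    # torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi).
    phase_only = np.nan_to_num(phase - phase_offset, nan=2 * np.pi)
    # Smoothness penalty on both the phase and the offset.
    loss = total_variation(phase) + total_variation(phase_offset)
    return phase_only, loss


phase = np.array([[0.0, 1.0], [2.0, 3.0]])
offset = np.full((2, 2), 0.5)
phase_only, loss = direct_phase_constrain_sketch(phase, offset)
```

A constant offset contributes no variation, so in this example the whole penalty comes from the phase itself.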

double_phase_constrain(phase, phase_offset)

Internal function to constrain a given phase similarly to double phase encoding.

Parameters:

  • phase
                         Input phase values to be constrained.
    
  • phase_offset
                         Input phase offset value.
    

Returns:

  • phase_only ( tensor ) –

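    Constrained output phase.

  • loss ( tensor ) –

    Regularization loss (multi-scale total variation and standard deviation of the low and high phase components).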

Source code in odak/learn/wave/optimizers.py
def double_phase_constrain(self, phase, phase_offset):
    """
    Internal function to constrain a given phase similarly to double phase encoding.

    Parameters
    ----------
    phase                      : torch.tensor
                                 Input phase values to be constrained.
    phase_offset               : torch.tensor
                                 Input phase offset value.

    Returns
    -------
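    phase_only                 : torch.tensor
                                 Constrained output phase.
    loss                       : torch.tensor
                                 Regularization loss (multi-scale total variation and standard deviation of the low and high phase components).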
    """
    phase_zero_mean = phase - torch.mean(phase)
    phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * np.pi)
    phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * np.pi)
    loss = multi_scale_total_variation_loss(phase_low, levels = 6)
    loss += multi_scale_total_variation_loss(phase_high, levels = 6)
    loss += torch.std(phase_low)
    loss += torch.std(phase_high)
    phase_only = torch.zeros_like(phase)
    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
    return phase_only, loss
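double_phase_constrain builds a checkerboard from phase_low = φ - θ and phase_high = φ + θ. This echoes the classic double-phase identity: a complex value A·e^{iφ} with 0 ≤ A ≤ 1 equals the mean of the unit phasors e^{i(φ+θ)} and e^{i(φ-θ)} when θ = arccos(A). In odak, θ is the learned phase_offset rather than arccos(A), so the NumPy sketch below only illustrates the underlying idea. It checks the identity numerically and reproduces the checkerboard interleaving; the helper name double_phase_interleave is hypothetical.

```python
import numpy as np


def double_phase_interleave(phase, phase_offset):
    # Zero-mean the phase, then form the two phase components.
    phase_zero_mean = phase - phase.mean()
    phase_low = phase_zero_mean - phase_offset
    phase_high = phase_zero_mean + phase_offset
    # Checkerboard: low on (even, even) and (odd, odd) pixels,
    # high on (even, odd) and (odd, even) pixels.
    phase_only = np.zeros_like(phase)
    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
    return phase_only


# Double-phase identity: A * exp(i * phi) is the mean of two unit phasors
# with phases phi + theta and phi - theta, where theta = arccos(A).
amplitude, phi = 0.6, 0.3
theta = np.arccos(amplitude)
two_phasors = 0.5 * (np.exp(1j * (phi + theta)) + np.exp(1j * (phi - theta)))
target = amplitude * np.exp(1j * phi)

checkerboard = double_phase_interleave(np.zeros((2, 2)), np.full((2, 2), 0.5))
```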

evaluate(input_image, target_image, plane_id=0)

Internal function to evaluate the loss.

Source code in odak/learn/wave/optimizers.py
def evaluate(self, input_image, target_image, plane_id = 0):
    """
    Internal function to evaluate the loss.
    """
    if self.loss_type == 'conventional':
        loss = self.loss_function(input_image, target_image)
    elif self.loss_type == 'custom':
        loss = 0
        for i in range(len(self.wavelengths)):
            loss += self.loss_function(
                                       input_image[i],
                                       target_image[i],
                                       plane_id = plane_id
                                      )
    return loss
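With a custom loss, evaluate adds up the loss channel by channel (one term per wavelength) and forwards plane_id to the loss function. The NumPy sketch below mirrors that branch. weighted_mse and evaluate_sketch are hypothetical names, and the per-plane weight is invented purely to show that plane_id reaches the loss.

```python
import numpy as np


def weighted_mse(prediction, target, plane_id=0):
    # Hypothetical custom loss: squared error summed over pixels,
    # scaled by a made-up per-plane weight for illustration.
    plane_weight = 1.0 + plane_id
    return plane_weight * np.sum((prediction - target) ** 2)


def evaluate_sketch(input_image, target_image, loss_function, number_of_channels, plane_id=0):
    # Mirror the 'custom' branch: accumulate the loss over color channels.
    loss = 0.0
    for i in range(number_of_channels):
        loss += loss_function(input_image[i], target_image[i], plane_id=plane_id)
    return loss


reconstruction = np.ones((3, 4, 4))  # [channels x m x n]
target = np.zeros((3, 4, 4))
loss = evaluate_sketch(reconstruction, target, weighted_mse, number_of_channels=3, plane_id=1)
```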

gradient_descent(number_of_iterations=100, weights=[1.0, 1.0, 0.0, 0.0])

Function to optimize multiplane phase-only holograms using stochastic gradient descent.

Parameters:

  • number_of_iterations
                         Number of iterations.
    
  • weights
                         Weights used in the loss function.
    

Returns:

  • hologram_phases ( tensor ) –

    Optimized hologram phases.

Source code in odak/learn/wave/optimizers.py
def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
    """
    Function to optimize multiplane phase-only holograms using stochastic gradient descent.

    Parameters
    ----------
    number_of_iterations       : int
                                 Number of iterations.
    weights                    : list
                                 Weights used in the loss function.

    Returns
    -------
    hologram_phases            : torch.tensor
                                 Optimized hologram phases.
    """
    hologram_phases = torch.zeros(
                                  self.number_of_frames,
                                  self.resolution[0],
                                  self.resolution[1],
                                  device = self.device
                                 )
    t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)
    if self.optimize_peak_amplitude:
        peak_amp_cache = self.peak_amplitude.item()
    for step in t:
        for g in self.optimizer.param_groups:
            g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations
            if g['lr'] < self.learning_rate_floor:
                g['lr'] = self.learning_rate_floor
            learning_rate = g['lr']
        total_loss = 0
        t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)
        for depth_id in t_depth:
            self.optimizer.zero_grad()
            depth_target = self.targets[depth_id]
            reconstruction_intensities = torch.zeros(
                                                     self.number_of_frames,
                                                     self.number_of_channels,
                                                     self.resolution[0] * self.scale_factor,
                                                     self.resolution[1] * self.scale_factor,
                                                     device = self.device
                                                    )
            loss_variation_hologram = 0
            laser_powers = self.propagator.get_laser_powers()
            for frame_id in range(self.number_of_frames):
                if self.double_phase:
                    phase, loss_phase = self.double_phase_constrain(
                                                                    self.phase[frame_id],
                                                                    self.offset[frame_id]
                                                                   )
                else:
                    phase, loss_phase = self.direct_phase_constrain(
                                                                    self.phase[frame_id],
                                                                    self.offset[frame_id]
                                                                   )
                loss_variation_hologram += loss_phase
                for channel_id in range(self.number_of_channels):
                    if self.scale_factor != 1:
                        phase_scaled = torch.zeros_like(self.amplitude)
                        phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
                        self.amplitude[1::self.scale_factor, 1::self.scale_factor] = 0.
                    else:
                        phase_scaled = phase
                    laser_power = laser_powers[frame_id][channel_id]
                    hologram = generate_complex_field(
                                                      laser_power * self.amplitude,
                                                      phase_scaled * self.phase_scale[channel_id]
                                                     )
                    reconstruction_field = self.propagator(hologram, channel_id, depth_id)
                    intensity = calculate_amplitude(reconstruction_field) ** 2
                    reconstruction_intensities[frame_id, channel_id] += intensity
                hologram_phases[frame_id] = phase.detach().clone()
            loss_laser = self.l2_loss(
                                      torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,
                                      torch.sum(laser_powers, dim = 0)
                                     )
            loss_laser += self.l2_loss(
                                       torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),
                                       torch.sum(laser_powers).view(1,)
                                      )
            loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))
            reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)
            loss_image = self.evaluate(
                                       reconstruction_intensity,
                                       depth_target * self.peak_amplitude,
                                       plane_id = depth_id
                                      )
            loss = weights[0] * loss_image
            loss += weights[1] * loss_laser
            loss += weights[2] * loss_variation_hologram
            include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres
            if include_pa_loss_flag:
                loss -= self.peak_amplitude * 1.
            if self.method == 'conventional':
                loss.backward()
            else:
                loss.backward(retain_graph = True)
            self.optimizer.step()
            if include_pa_loss_flag:
                peak_amp_cache = self.peak_amplitude.item()
            else:
                with torch.no_grad():
                    if self.optimize_peak_amplitude:
                        self.peak_amplitude.view([1])[0] = peak_amp_cache
            total_loss += loss.detach().item()
            loss_image = loss_image.detach()
            del loss_laser
            del loss_variation_hologram
            del loss
        description = "Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)
        t.set_description(description)
        del total_loss
        del loss_image
        del reconstruction_field
        del reconstruction_intensities
        del intensity
        del phase
        del hologram
    logging.warning(description)
    return hologram_phases.detach()
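At the start of every iteration, gradient_descent lowers the learning rate by a constant step of (learning_rate - learning_rate_floor) / number_of_iterations and clamps it at the floor. The plain-Python sketch below reproduces only that schedule. The function name is hypothetical, and the optimizer is left out.

```python
def linear_decay_schedule(learning_rate, learning_rate_floor, number_of_iterations):
    # Same update as the param_groups loop in gradient_descent:
    # subtract a constant step, then clamp at the floor.
    step = (learning_rate - learning_rate_floor) / number_of_iterations
    current = learning_rate
    schedule = []
    for _ in range(number_of_iterations):
        current -= step
        if current < learning_rate_floor:
            current = learning_rate_floor
        schedule.append(current)
    return schedule


rates = linear_decay_schedule(learning_rate=0.1, learning_rate_floor=0.02, number_of_iterations=4)
```

The decrement is applied before each iteration's update, so even the first iteration runs at one step below the initial learning rate. The last iteration runs at the floor.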

init_amplitude()

Internal function to set the amplitude of the illumination source.

Source code in odak/learn/wave/optimizers.py
def init_amplitude(self):
    """
    Internal function to set the amplitude of the illumination source.
    """
    self.amplitude = torch.ones(
                                self.resolution[0],
                                self.resolution[1],
                                requires_grad = False,
                                device = self.device
                               )

init_channel_power()

Internal function to set the power of each color channel in each hologram frame.

Source code in odak/learn/wave/optimizers.py
def init_channel_power(self):
    """
    Internal function to set the power of each color channel in each hologram frame.
    """
    if self.method == 'conventional':
        logging.warning('Scheme: Conventional')
        self.channel_power = torch.eye(
                                       self.number_of_frames,
                                       self.number_of_channels,
                                       device = self.device,
                                       requires_grad = False
                                      )

    elif self.method == 'multi-color':
        logging.warning('Scheme: Multi-color')
        self.channel_power = torch.ones(
                                        self.number_of_frames,
                                        self.number_of_channels,
                                        device = self.device,
                                        requires_grad = True
                                       )
    if self.channel_power_filename != '':
        self.channel_power = torch_load(self.channel_power_filename).to(self.device)
        self.channel_power.requires_grad = False
        self.channel_power[self.channel_power < 0.] = 0.
        self.channel_power[self.channel_power > 1.] = 1.
        if self.method == 'multi-color':
            self.channel_power.requires_grad = True
        if self.method == 'conventional':
            self.channel_power = torch.abs(torch.cos(self.channel_power))
        logging.warning('Channel powers:')
        logging.warning(self.channel_power)
        logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))
    self.propagator.set_laser_powers(self.channel_power)
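The two schemes in init_channel_power start from different [frames x channels] matrices. 'conventional' uses an identity matrix, so each frame lights a single color primary. 'multi-color' starts with every channel fully on in every frame and optimizes those values. Powers loaded from a file are clipped to [0, 1]. The NumPy sketch below shows these matrices and checks that np.clip matches the two masked assignments used in the source. The additional abs(cos(.)) mapping applied to loaded powers in the 'conventional' case is not reproduced here.

```python
import numpy as np

number_of_frames, number_of_channels = 3, 3

# 'conventional': frame i lights only channel i (one primary per frame).
conventional = np.eye(number_of_frames, number_of_channels)

# 'multi-color': every channel starts at full power in every frame;
# odak makes this tensor optimizable (requires_grad = True).
multi_color = np.ones((number_of_frames, number_of_channels))

# Loaded powers are limited to [0, 1] with two masked assignments ...
loaded = np.array([[1.3, -0.2, 0.5], [0.0, 0.9, 2.0], [0.1, 0.1, 0.1]])
masked = loaded.copy()
masked[masked < 0.] = 0.
masked[masked > 1.] = 1.
# ... which is equivalent to a single clip.
clipped = np.clip(loaded, 0.0, 1.0)
```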

init_loss_function(loss_function, reduction='sum')

Internal function to set the loss function.

Source code in odak/learn/wave/optimizers.py
def init_loss_function(self, loss_function, reduction = 'sum'):
    """
    Internal function to set the loss function.
    """
    self.l2_loss = torch.nn.MSELoss(reduction = reduction)
    self.loss_type = 'custom'
    self.loss_function = loss_function
    if isinstance(self.loss_function, type(None)):
        self.loss_type = 'conventional'
        self.loss_function = torch.nn.MSELoss(reduction = reduction)
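init_loss_function defaults to reduction = 'sum'. torch.nn.MSELoss with 'sum' adds the squared errors over all elements, while 'mean' divides that total by the element count. With 'sum', the loss magnitude therefore grows with resolution, which affects how the loss weights and the learning rate should be chosen. The NumPy sketch below mirrors those two reductions.

```python
import numpy as np


def mse(prediction, target, reduction="sum"):
    # NumPy equivalent of torch.nn.MSELoss for the 'sum' and 'mean' reductions.
    squared_error = (prediction - target) ** 2
    if reduction == "sum":
        return squared_error.sum()
    if reduction == "mean":
        return squared_error.mean()
    raise ValueError("unsupported reduction: {}".format(reduction))


small = mse(np.ones((2, 2)), np.zeros((2, 2)))   # 4 elements
large = mse(np.ones((4, 4)), np.zeros((4, 4)))   # 16 elements
mean_small = mse(np.ones((2, 2)), np.zeros((2, 2)), reduction="mean")
mean_large = mse(np.ones((4, 4)), np.zeros((4, 4)), reduction="mean")
```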

init_optimizer()

Internal function to set the optimizer.

Source code in odak/learn/wave/optimizers.py
def init_optimizer(self):
    """
    Internal function to set the optimizer.
    """
    optimization_variables = [self.phase, self.offset]
    if self.optimize_peak_amplitude:
        optimization_variables.append(self.peak_amplitude)
    if self.method == 'multi-color':
        optimization_variables.append(self.propagator.channel_power)
    self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)

init_peak_amplitude_scale()

Internal function to convert the peak amplitude into a differentiable tensor.

Source code in odak/learn/wave/optimizers.py
def init_peak_amplitude_scale(self):
    """
    Internal function to convert the peak amplitude into a differentiable tensor.
    """
    self.peak_amplitude = torch.tensor(
                                       self.peak_amplitude,
                                       requires_grad = True,
                                       device=self.device
                                      )

init_phase()

Internal function to set the starting phase of the phase-only hologram.

Source code in odak/learn/wave/optimizers.py
def init_phase(self):
    """
    Internal function to set the starting phase of the phase-only hologram.
    """
    self.phase = torch.zeros(
                             self.number_of_frames,
                             self.resolution[0],
                             self.resolution[1],
                             device = self.device,
                             requires_grad = True
                            )
    self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)

init_phase_scale()

Internal function to set the phase scale.

Source code in odak/learn/wave/optimizers.py
def init_phase_scale(self):
    """
    Internal function to set the phase scale.
    """
    if self.method == 'conventional':
        self.phase_scale = torch.tensor(
                                        [
                                         1.,
                                         1.,
                                         1.
                                        ],
                                        requires_grad = False,
                                        device = self.device
                                       )
    if self.method == 'multi-color':
        self.phase_scale = torch.tensor(
                                        [
                                         1.,
                                         1.,
                                         1.
                                        ],
                                        requires_grad = False,
                                        device = self.device
                                       )

optimize(number_of_iterations=100, weights=[1.0, 1.0, 1.0], bits=8)

Function to optimize multiplane phase-only holograms.

Parameters:

  • number_of_iterations
                         Number of iterations.
    
  • weights
                         Loss weights.
    
  • bits
                         Quantizes the hologram using the given bits and reconstructs.
    

Returns:

  • hologram_phases ( tensor ) –

    Phases of the optimized phase-only hologram.

  • reconstruction_intensities ( tensor ) –

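    Intensities of the images reconstructed at each plane with the optimized phase-only hologram.

  • laser_powers ( tensor ) –

    Laser powers of each frame and color channel, as returned by the propagator.

  • channel_powers ( tensor ) –

    Channel powers stored in the propagator.

  • peak_amplitude ( float ) –

    Final peak amplitude.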

Source code in odak/learn/wave/optimizers.py
def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8):
    """
    Function to optimize multiplane phase-only holograms.

    Parameters
    ----------
    number_of_iterations       : int
                                 Number of iterations.
    weights                    : list
                                 Loss weights.
    bits                       : int
                                 Quantizes the hologram using the given bits and reconstructs.

    Returns
    -------
    hologram_phases            : torch.tensor
                                 Phases of the optimized phase-only hologram.
    reconstruction_intensities : torch.tensor
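                                 Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
    laser_powers               : torch.tensor
                                 Laser powers of each frame and color channel, as returned by the propagator.
    channel_powers             : torch.tensor
                                 Channel powers stored in the propagator.
    peak_amplitude             : float
                                 Final peak amplitude.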
    """
    self.init_optimizer()
    hologram_phases = self.gradient_descent(
                                            number_of_iterations=number_of_iterations,
                                            weights=weights
                                           )
    hologram_phases = quantize(hologram_phases % (2 * np.pi), bits = bits, limits = [0., 2 * np.pi]) / 2 ** bits * 2 * np.pi
    with torch.no_grad():
        reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
    laser_powers = self.propagator.get_laser_powers()
    channel_powers = self.propagator.channel_power
    logging.warning("Final peak amplitude: {}".format(self.peak_amplitude))
    logging.warning('Laser powers: {}'.format(laser_powers))
    return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)
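optimize first wraps the phases into [0, 2π). It then quantizes them with quantize(..., bits = bits, limits = [0., 2 * np.pi]) and converts the integer levels back to radians with / 2 ** bits * 2 * np.pi. The internals of odak's quantize are not shown on this page. The NumPy sketch below therefore assumes a hypothetical quantize_sketch that linearly maps [0, 2π) onto integer levels 0 to 2**bits - 1 by flooring, and checks that the round-trip error stays within one quantization step.

```python
import numpy as np


def quantize_sketch(values, bits=8, limits=(0.0, 2 * np.pi)):
    # ASSUMPTION: hypothetical stand-in for odak's quantize(), mapping
    # [limits[0], limits[1]) linearly onto integer levels 0 .. 2**bits - 1.
    normalized = (values - limits[0]) / (limits[1] - limits[0])
    levels = np.floor(normalized * 2 ** bits)
    return np.clip(levels, 0, 2 ** bits - 1).astype(np.int64)


bits = 8
phases = np.linspace(-3 * np.pi, 3 * np.pi, 1001)
wrapped = phases % (2 * np.pi)                       # wrap into [0, 2 * pi)
levels = quantize_sketch(wrapped, bits=bits)
quantized_phases = levels / 2 ** bits * 2 * np.pi    # back to radians, as in optimize()
max_error = np.max(np.abs(wrapped - quantized_phases))
```

With 8 bits, one quantization step is 2π / 256, roughly 0.0245 rad.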

propagator

A light propagation model that propagates light to the desired image plane with two separate propagations. We use this class in our various works including Kavaklı et al., Realistic Defocus Blur for Multiplane Computer-Generated Holography.
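One property that keeps a two-step propagation consistent can be seen with an ideal angular spectrum transfer function. The transfer functions multiply, so propagating back by d and then forward by d + z matches a single propagation by z. The NumPy sketch below checks this numerically with a hypothetical angular_spectrum_sketch helper, which has no band-limiting and no zero padding. It is not odak's get_propagation_kernel or its 'Bandlimited Angular Spectrum' option, and the exact distances odak uses for 'back and forth' are not claimed here. The wavelength, pixel pitch and distances are borrowed from the propagator defaults for illustration only.

```python
import numpy as np


def angular_spectrum_sketch(field, distance, dx, wavelength):
    # Hypothetical helper: ideal angular spectrum propagation with
    # evanescent components removed (no band-limiting, no zero padding).
    m, n = field.shape
    fx = np.fft.fftfreq(n, d=dx)
    fy = np.fft.fftfreq(m, d=dx)
    FX, FY = np.meshgrid(fx, fy)                    # both shaped [m x n]
    argument = 1.0 - (wavelength * FX) ** 2 - (wavelength * FY) ** 2
    propagating = argument > 0
    k = 2 * np.pi / wavelength
    kz = k * np.sqrt(np.where(propagating, argument, 0.0))
    transfer_function = np.where(propagating, np.exp(1j * kz * distance), 0.0)
    return np.fft.ifft2(np.fft.fft2(field) * transfer_function)


rng = np.random.default_rng(0)
hologram = np.exp(1j * 2 * np.pi * rng.random((64, 64)))  # random phase-only field
dx, wavelength = 8e-6, 515e-9
d, z = 0.3, 5e-3

# Back by d, then forward by d + z, versus a single propagation by z.
two_step = angular_spectrum_sketch(
    angular_spectrum_sketch(hologram, -d, dx, wavelength), d + z, dx, wavelength
)
one_step = angular_spectrum_sketch(hologram, z, dx, wavelength)
```

At this pixel pitch every spatial frequency on the grid propagates, so the transfer function has unit modulus and the field energy is preserved as well.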

Source code in odak/learn/wave/propagators.py
class propagator():
    """
    A light propagation model that propagates light to the desired image plane with two separate propagations.
    We use this class in our various works including `Kavaklı et al., Realistic Defocus Blur for Multiplane Computer-Generated Holography`.
    """
    def __init__(
                 self,
                 resolution = [1920, 1080],
                 wavelengths = [515e-9,],
                 pixel_pitch = 8e-6,
                 resolution_factor = 1,
                 number_of_frames = 1,
                 number_of_depth_layers = 1,
                 volume_depth = 1e-2,
                 image_location_offset = 5e-3,
                 propagation_type = 'Bandlimited Angular Spectrum',
                 propagator_type = 'back and forth',
                 back_and_forth_distance = 0.3,
                 laser_channel_power = None,
                 aperture = None,
                 aperture_size = None,
                 method = 'conventional',
                 device = torch.device('cpu')
                ):
        """
        Parameters
        ----------
        resolution              : list
                                  Resolution.
        wavelengths             : list
                                  Wavelengths of light in meters, one per color channel.
        pixel_pitch             : float
                                  Pixel pitch in meters.
        resolution_factor       : int
                                  Resolution factor for scaled simulations.
        number_of_frames        : int
                                  Number of hologram frames.
                                  Typically, there are three frames, each one for a single color primary.
        number_of_depth_layers  : int
                                  Number of equidistant depth layers within the desired volume.
        volume_depth            : float
                                  Width of the volume along the propagation direction.
        image_location_offset   : float
                                  Center of the volume along the propagation direction.
        propagation_type        : str
                                  Propagation type.
                                  See propagate_beam() and odak.learn.wave.get_propagation_kernel() for more.
        propagator_type         : str
                                  Propagator type.
                                  The options are `back and forth` and `forward` propagators.
        back_and_forth_distance : float
                                  Zero mode distance for `back and forth` propagator type.
        laser_channel_power     : torch.tensor
                                  Laser channel powers for given number of frames and number of wavelengths.
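        aperture                : torch.tensor
                                  Aperture at the Fourier plane.
        aperture_size           : int
                                  Size of the circular aperture used when no aperture is provided. See set_aperture() for more.
        method                  : str
                                  Hologram type, either `conventional` or `multi-color`.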
        device                  : torch.device
                                  Device to be used for computation. For more see torch.device().
        """
        self.device = device
        self.pixel_pitch = pixel_pitch
        self.wavelengths = wavelengths
        self.resolution = resolution
        self.resolution_factor = resolution_factor
        self.number_of_frames = number_of_frames
        self.number_of_depth_layers = number_of_depth_layers
        self.number_of_channels = len(self.wavelengths)
        self.volume_depth = volume_depth
        self.image_location_offset = image_location_offset
        self.propagation_type = propagation_type
        self.propagator_type = propagator_type
        self.zero_mode_distance = torch.tensor(back_and_forth_distance, device = device)
        self.method = method
        self.aperture = aperture
        self.init_distances()
        self.init_kernels()
        self.init_channel_power(laser_channel_power)
        self.init_phase_scale()
        self.set_aperture(aperture, aperture_size)


    def init_distances(self):
        """
        Internal function to initialize distances.
        """
        self.distances = torch.linspace(-self.volume_depth / 2., self.volume_depth / 2., self.number_of_depth_layers) + self.image_location_offset
        logging.warning('Distances: {}'.format(self.distances))


    def init_kernels(self):
        """
        Internal function to initialize kernels.
        """
        self.generated_kernels = torch.zeros(
                                             self.number_of_depth_layers,
                                             self.number_of_channels,
                                             device = self.device
                                            )
        self.kernels = torch.zeros(
                                   self.number_of_depth_layers,
                                   self.number_of_channels,
                                   self.resolution[0] * self.resolution_factor * 2,
                                   self.resolution[1] * self.resolution_factor * 2,
                                   dtype = torch.complex64,
                                   device = self.device
                                  )


    def init_channel_power(self, channel_power):
        """
        Internal function to initialize the laser channel powers. If none are provided, an identity matrix [number_of_frames x number_of_channels] is used.
        """
        self.channel_power = channel_power
        if isinstance(self.channel_power, type(None)):
            self.channel_power = torch.eye(
                                           self.number_of_frames,
                                           self.number_of_channels,
                                           device = self.device,
                                           requires_grad = False
                                          )


    def init_phase_scale(self):
        """
        Internal function to set the phase scale.
        In some cases, you may want to modify this init to scale the phases of different color primaries, as an SLM is configured for a specific central wavelength.
        """
        self.phase_scale = torch.tensor(
                                        [
                                         1.,
                                         1.,
                                         1.
                                        ],
                                        requires_grad = False,
                                        device = self.device
                                       )


    def set_aperture(self, aperture = None, aperture_size = None):
        """
        Set aperture in the Fourier plane.


        Parameters
        ----------
        aperture        : torch.tensor
                          Aperture at the original resolution of a hologram.
                          If aperture is None, a circular aperture sized by the longer edge (width or height) is assigned.
        aperture_size   : int
                          If no aperture is provided, this will determine the size of the circular aperture.
        """
        if isinstance(aperture, type(None)):
            if isinstance(aperture_size, type(None)):
                aperture_size = torch.max(
                                          torch.tensor([
                                                        self.resolution[0] * self.resolution_factor, 
                                                        self.resolution[1] * self.resolution_factor
                                                       ])
                                         )
            self.aperture = circular_binary_mask(
                                                 self.resolution[0] * self.resolution_factor * 2,
                                                 self.resolution[1] * self.resolution_factor * 2,
                                                 aperture_size,
                                                ).to(self.device) * 1.
        else:
            self.aperture = zero_pad(aperture).to(self.device) * 1.


    def get_laser_powers(self):
        """
        Internal function to get the laser powers.

        Returns
        -------
        laser_power      : torch.tensor
                           Laser powers.
        """
        if self.method == 'conventional':
            laser_power = self.channel_power
        if self.method == 'multi-color':
            laser_power = torch.abs(torch.cos(self.channel_power))
        return laser_power


    def set_laser_powers(self, laser_power):
        """
        Internal function to set the laser powers.

        Parameters
        ----------
        laser_power      : torch.tensor
                           Laser powers.
        """
        self.channel_power = laser_power



    def get_kernels(self):
        """
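        Function to return the kernels used in the light transport, in the spatial domain.

        Returns
        -------
        kernels_amplitude : torch.tensor
                            Kernel amplitudes.
        kernels_phase     : torch.tensor
                            Kernel phases.
        """
        # Shift and transform only the two spatial dimensions, not the depth and channel dimensions.
        h = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(self.kernels, dim = (-2, -1))), dim = (-2, -1))
        kernels_amplitude = calculate_amplitude(h)
        kernels_phase = calculate_phase(h)
        return kernels_amplitude, kernels_phase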


    def propagate(self, field, H):
        """
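        Internal function used in propagation. It follows odak.learn.wave.band_limited_angular_spectrum(), but multiplies the spectrum of the zero-padded field by a precomputed kernel H and the Fourier-plane aperture.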
        """
        field_padded = zero_pad(field)
        U1 = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(field_padded)))
        U2 = H * self.aperture * U1
        result_padded = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(U2)))
        result = crop_center(result_padded)
        return result


    def __call__(self, input_field, channel_id, depth_id):
        """
        Function that represents the forward model in hologram optimization.

        Parameters
        ----------
        input_field         : torch.tensor
                              Input complex field.
        channel_id          : int
                              Index of the color primary to be used.
        depth_id            : int
                              Index of the depth layer to be used.

        Returns
        -------
        output_field        : torch.tensor
                              Propagated output complex field.
        """
        distance = self.distances[depth_id]
        if not self.generated_kernels[depth_id, channel_id]:
            if self.propagator_type == 'forward':
                H = get_propagation_kernel(
                                           nu = input_field.shape[-2] * 2,
                                           nv = input_field.shape[-1] * 2,
                                           dx = self.pixel_pitch,
                                           wavelength = self.wavelengths[channel_id],
                                           distance = distance,
                                           device = self.device,
                                           propagation_type = self.propagation_type,
                                           scale = self.resolution_factor
                                          )
            elif self.propagator_type == 'back and forth':
                H_forward = get_propagation_kernel(
                                                   nu = input_field.shape[-2] * 2,
                                                   nv = input_field.shape[-1] * 2,
                                                   dx = self.pixel_pitch,
                                                   wavelength = self.wavelengths[channel_id],
                                                   distance = self.zero_mode_distance,
                                                   device = self.device,
                                                   propagation_type = self.propagation_type,
                                                   scale = self.resolution_factor
                                                  )
                distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)
                H_back = get_propagation_kernel(
                                                nu = input_field.shape[-2] * 2,
                                                nv = input_field.shape[-1] * 2,
                                                dx = self.pixel_pitch,
                                                wavelength = self.wavelengths[channel_id],
                                                distance = distance_back,
                                                device = self.device,
                                                propagation_type = self.propagation_type,
                                                scale = self.resolution_factor
                                               )
                H = H_forward * H_back
            self.kernels[depth_id, channel_id] = H
            self.generated_kernels[depth_id, channel_id] = True
        else:
            H = self.kernels[depth_id, channel_id].detach().clone()
        output_field = self.propagate(input_field, H)
        return output_field


    def reconstruct(self, hologram_phases, amplitude = None, no_grad = True):
        """
        Internal function to reconstruct a given hologram.


        Parameters
        ----------
        hologram_phases            : torch.tensor
                                     Hologram phases [ch x m x n].
        amplitude                  : torch.tensor
                                     Amplitude profiles for each color primary [ch x m x n]
        no_grad                    : bool
                                     If set True, uses torch.no_grad in reconstruction.

        Returns
        -------
        reconstruction_intensities : torch.tensor
                                     Reconstructed frames.
        """
        # A bare torch.no_grad() call has no effect, so wrap the reconstruction in a grad context.
        with torch.set_grad_enabled(not no_grad):
            if len(hologram_phases.shape) > 3:
                hologram_phases = hologram_phases.squeeze(0)
            reconstruction_intensities = torch.zeros(
                                                     self.number_of_frames,
                                                     self.number_of_depth_layers,
                                                     self.number_of_channels,
                                                     self.resolution[0] * self.resolution_factor,
                                                     self.resolution[1] * self.resolution_factor,
                                                     device = self.device
                                                    )
            if isinstance(amplitude, type(None)):
                amplitude = torch.ones(
                                       self.number_of_channels,
                                       self.resolution[0] * self.resolution_factor,
                                       self.resolution[1] * self.resolution_factor,
                                       device = self.device
                                      )
            for frame_id in range(self.number_of_frames):
                for depth_id in range(self.number_of_depth_layers):
                    for channel_id in range(self.number_of_channels):
                        laser_power = self.get_laser_powers()[frame_id][channel_id]
                        if self.resolution_factor != 1:
                            # Scaled simulation: place each hologram pixel on every resolution_factor-th sample.
                            phase = torch.zeros_like(amplitude[channel_id])
                            phase[::self.resolution_factor, ::self.resolution_factor] = hologram_phases[frame_id]
                            amplitude[:, 1::self.resolution_factor, 1::self.resolution_factor] = 0.
                        else:
                            phase = hologram_phases[frame_id]
                        hologram = generate_complex_field(
                                                          laser_power * amplitude[channel_id],
                                                          phase * self.phase_scale[channel_id]
                                                         )
                        reconstruction_field = self.__call__(hologram, channel_id, depth_id)
                        reconstruction_intensities[
                                                   frame_id,
                                                   depth_id,
                                                   channel_id
                                                  ] = calculate_amplitude(reconstruction_field).detach().clone() ** 2
        return reconstruction_intensities
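
`init_kernels` and `set_aperture` allocate their grids at twice the hologram resolution because `propagate` zero-pads the field before the FFT and crops the center afterwards; padding reduces the wrap-around that FFT-based (circular) convolution would otherwise introduce. The stdlib sketch below mimics only that pad/crop bookkeeping on nested lists. `zero_pad_2x` and `crop_center_half` are hypothetical stand-ins that assume odak's `zero_pad` and `crop_center` center the field in a grid of double size and cut it back out.

```python
def zero_pad_2x(field):
    """Hypothetical helper: embed an m x n field in the center of a 2m x 2n grid of zeros."""
    m, n = len(field), len(field[0])
    top, left = m // 2, n // 2
    padded = [[0j] * (2 * n) for _ in range(2 * m)]
    for i in range(m):
        for j in range(n):
            padded[top + i][left + j] = field[i][j]
    return padded


def crop_center_half(padded):
    """Hypothetical helper: cut the central half-size region back out of a padded grid."""
    rows, cols = len(padded), len(padded[0])
    m, n = rows // 2, cols // 2
    top, left = m // 2, n // 2
    return [row[left:left + n] for row in padded[top:top + m]]


field = [[1 + 0j, 2 + 0j, 3 + 0j], [4 + 0j, 5 + 0j, 6 + 0j]]  # 2 x 3 toy field
padded = zero_pad_2x(field)
print(len(padded), len(padded[0]))  # prints 4 6, the kernel-sized grid
print(crop_center_half(padded) == field)  # prints True
```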

__call__(input_field, channel_id, depth_id)

Function that represents the forward model in hologram optimization.

Parameters:

  • input_field
                  Input complex field.
    
  • channel_id
                  Index of the color primary to be used.
    
  • depth_id
                  Index of the depth layer to be used.
    

Returns:

  • output_field ( tensor ) –

    Propagated output complex field.
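
`__call__` reads its propagation distance from `self.distances[depth_id]`, which `init_distances` spreads evenly across `volume_depth` and shifts by `image_location_offset`. The stdlib sketch below repeats that arithmetic with the class defaults for three layers; `linspace` is a hypothetical stand-in for `torch.linspace`, assumed to return only the start value when asked for a single step.

```python
def linspace(start, stop, steps):
    """Hypothetical stand-in for torch.linspace: `steps` evenly spaced values, endpoints included."""
    if steps == 1:
        return [start]
    step = (stop - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


volume_depth = 1e-2
image_location_offset = 5e-3
number_of_depth_layers = 3

# Mirrors init_distances: layers spread over volume_depth, centered on image_location_offset.
distances = [
             offset + image_location_offset
             for offset in linspace(-volume_depth / 2., volume_depth / 2., number_of_depth_layers)
            ]

for depth_id, distance in enumerate(distances):
    print(depth_id, distance)
```

Under that assumption, the default `number_of_depth_layers = 1` places the single layer at `image_location_offset - volume_depth / 2` (0 m with the defaults) rather than at `image_location_offset`.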

__init__(resolution=[1920, 1080], wavelengths=[5.15e-07], pixel_pitch=8e-06, resolution_factor=1, number_of_frames=1, number_of_depth_layers=1, volume_depth=0.01, image_location_offset=0.005, propagation_type='Bandlimited Angular Spectrum', propagator_type='back and forth', back_and_forth_distance=0.3, laser_channel_power=None, aperture=None, aperture_size=None, method='conventional', device=torch.device('cpu'))

Parameters:

  • resolution
                      Resolution.
    
  • wavelengths
                      Wavelengths of light in meters, one per color channel.
    
  • pixel_pitch
                      Pixel pitch in meters.
    
  • resolution_factor
                      Resolution factor for scaled simulations.
    
  • number_of_frames
                      Number of hologram frames.
                      Typically, there are three frames, each one for a single color primary.
    
  • number_of_depth_layers
                      Number of equidistant depth layers within the desired volume.
    
  • volume_depth
                      Width of the volume along the propagation direction.
    
  • image_location_offset
                      Center of the volume along the propagation direction.
    
  • propagation_type
                      Propagation type.
                      See propagate_beam() and odak.learn.wave.get_propagation_kernel() for more.
    
  • propagator_type
                      Propagator type.
                      The options are `back and forth` and `forward` propagators.
    
  • back_and_forth_distance (float, default: 0.3 ) –
                      Zero mode distance for `back and forth` propagator type.
    
  • laser_channel_power
                      Laser channel powers for given number of frames and number of wavelengths.
    
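  • aperture
                      Aperture at the Fourier plane.
    
  • aperture_size
                      Size of the circular aperture used when no aperture is provided. See set_aperture() for more.
    
  • method
                      Hologram type, either `conventional` or `multi-color`.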
    
  • device
                      Device to be used for computation. For more see torch.device().
    
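When `laser_channel_power` is None, `init_channel_power` fills it with an identity matrix of shape [number_of_frames x number_of_channels], so frame i lights only channel i; with `method = 'multi-color'`, `get_laser_powers` instead reports |cos(p)| of the stored values, which keeps the effective powers in [0, 1]. The stdlib sketch below mirrors both mappings with plain lists in place of tensors; the helper names are hypothetical.

```python
import math


def default_channel_power(number_of_frames, number_of_channels):
    """Hypothetical stand-in for torch.eye(number_of_frames, number_of_channels)."""
    return [
            [1.0 if frame_id == channel_id else 0.0 for channel_id in range(number_of_channels)]
            for frame_id in range(number_of_frames)
           ]


def laser_powers(channel_power, method):
    """Mirrors get_laser_powers: raw values for 'conventional', |cos| for 'multi-color'."""
    if method == 'conventional':
        return channel_power
    if method == 'multi-color':
        return [[abs(math.cos(power)) for power in row] for row in channel_power]
    # The source has no fallback branch; raising here is a choice of this sketch.
    raise ValueError('Unknown method: {}'.format(method))


# Three frames and three color primaries: frame i lights only laser i.
powers = default_channel_power(3, 3)
print(laser_powers(powers, 'conventional'))
print(laser_powers(powers, 'multi-color'))
```

Because |cos| maps 0 to 1, the identity initialization does not mean one laser per frame under `multi-color`: the off-diagonal entries report full power.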
Source code in odak/learn/wave/propagators.py