Skip to content

Commit

Permalink
Apply changes to documentation to conform to pydocstyle
Browse files Browse the repository at this point in the history
  • Loading branch information
noc0lour committed Jan 10, 2025
1 parent 11d02a3 commit 40426a6
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 20 deletions.
142 changes: 129 additions & 13 deletions src/mokka/channels/torch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Channels sub-module implemented within the PyTorch framework."""

import torch

import torchaudio
Expand Down Expand Up @@ -968,7 +969,6 @@ def forward(self, ux, uy):
:param uy: input signal in y-polarization.
:returns: signal at the end of the fiber
"""

w = (2 * torch.pi * torch.fft.fftfreq(ux.shape[0], self.dt * 1e12)).to(
self.betapa.device
) # THz
Expand Down Expand Up @@ -1380,14 +1380,25 @@ def classical_demux(self, signal):

class PMDElement(torch.nn.Module):
"""
Static and dynamic PMD Element according to Czegledi (2016)
Static and dynamic PMD Element according to Czegledi (2016).
Note: for SSFM only static PMD is correct
"""

def __init__(
self, sigma_p, pmd_parameter, span_length, steps_per_span, method="static"
):
"""
Initialize :py:class:`PMDElement`.
:param sigma_p: Variance of time-varying Wiener process
:param pmd_parameter: Calculate second moment of the DGD based
on the PMD parameter
:param span_length: Length of each span of optical fiber
:param steps_per_span: Simulation steps per span
:param method: Either "static" or "dynamic" - sets the mode
to either time-varying or fixed in time
"""
super(PMDElement, self).__init__()
# Apply PMD using matrix J_k which can be calculated from J(\alpha_k)
# Pauli spin matrices
Expand Down Expand Up @@ -1420,14 +1431,17 @@ def __init__(

@property
def a(self):
"""Calculate and return parameters a."""
return self.alpha / self.theta

@property
def theta(self):
"""Calculate theta as norm of vector alpha."""
return torch.linalg.vector_norm(self.alpha)

@property
def J(self):
"""Calculate the Jones matrix of the rotation portion."""
return torch.matmul(
torch.eye(2, dtype=torch.complex64) * torch.cos(self.theta)
- 1j
Expand All @@ -1437,13 +1451,15 @@ def J(self):
)

def step(self):
"""Perform one time-step for the time-varying simulation and return `J_k1`."""
if self.sigma_p == 0.0:
return self.J_k1
self.alpha = torch.zeros((3,), dtype=torch.float32).normal_() * self.sigma_p
self.J_k1 = self.J # This applies the J_delta approach to update J_k1
return self.J_k1

def steps(self, k=1):
"""Perform `k` time-varying steps and return the Jones matrices."""
if self.sigma_p == 0.0:
return self.J_k1.expand(k, -1, -1)
alphas = torch.zeros((k, 3), dtype=torch.float32).normal_() * self.sigma_p
Expand All @@ -1461,9 +1477,11 @@ def steps(self, k=1):
return torch.stack(results)

def forward_static(self, signal: torch.Tensor):
"""Apply the static PMD simulation to the input signal."""
return torch.matmul(self.J_k1, signal)

def forward_dynamic(self, signal: torch.Tensor):
"""Apply the time-varying PMD simulation to the input signal."""
num_steps = signal.shape[1]
J_k = self.steps(num_steps)
signal_out = torch.bmm(J_k, signal.unsqueeze(-1)).squeeze()
Expand All @@ -1482,16 +1500,23 @@ def forward(self, signal: torch.Tensor):


class FixedChannelDP(torch.nn.Module):
"""
Apply a fixed channel impulse response on both polarization separately
"""
"""Apply a fixed channel impulse response on both polarization separately."""

def __init__(self, impulse_response):
"""
Initialize :py:class:`FixedChannelDP`.
:param impulse_response: single polarization impulse response
"""
super(FixedChannelDP, self).__init__()

self.impulse_response = torch.as_tensor(impulse_response)

def forward(self, tx_signal):
"""Apply static dual polarization signal to `tx_signal`.
:param tx_signal: dual polarization input signal
"""
return torch.cat(
(
convolve(tx_signal[0, :], self.impulse_response, mode="full").unsqueeze(
Expand All @@ -1507,15 +1532,23 @@ def forward(self, tx_signal):

class FixedArbitraryChannelDP(torch.nn.Module):
"""
Apply a fixed channel impulse response on both polarization separately
Apply a fixed 2x2 channel impulse response to a dual polarization signal.
This class only implements a time-invariant dual-polarization channel.
"""

def __init__(self, impulse_response):
"""
Initialize :py:class:`FixedArbitraryChannelDP`.
:param impulse_response: arbitrary 2x2 impulse response.
"""
super(FixedArbitraryChannelDP, self).__init__()

self.impulse_response = torch.as_tensor(impulse_response)

def forward(self, tx_signal):
"""Apply arbitrary dual polarization channel to `tx_signal`."""
return torch.stack(
(
convolve(tx_signal[0, :], self.impulse_response[0, :], mode="full")
Expand All @@ -1528,18 +1561,43 @@ def forward(self, tx_signal):

class FixedChannelSP(torch.nn.Module):
"""
Apply a fixed channel impulse response for a single polarization
Apply a fixed channel impulse response to a single polarization input signal.
This class only implements a time-invariant single-polarization channel.
"""

def __init__(self, impulse_response):
"""
Initialize :py:class:`FixedChannelSP`.
:param impulse_response: 1xN vector of complex time-domain samples of the impulse response
"""
super(FixedChannelSP, self).__init__()
self.impulse_response = torch.as_tensor(impulse_response)

def forward(self, tx_signal):
"""
Apply fixed single-polarization channel on `tx_signal`.
:param tx_signal: Single polarization complex signal vector
"""
return convolve(tx_signal, self.impulse_response, mode="full")


def ProakisChannel(variant, sps=1):
"""
Return impulse response for channels defined in [0].
:param variant: Either "a", "b", "c" or "a_complex".
In the literature all channels are real-valued. The complex channel "a_complex"
is an extension of the "a" channel by applying a phase rotation of the given
samples.
:param sps: samples-per-symbol choosing an integer value greater than 1 will
add sps-1 zeros in-between the channel given channel taps. No band-limiting filter
is applied subsequently.
[0] Proakis, John G., and Masoud Salehi. Digital communications. McGraw-hill, 2008.
"""
if variant == "a":
h = torch.tensor(
[0.04, -0.05, 0.07, -0.21, -0.5, 0.72, 0.36, 0.0, 0.21, 0.03, 0.07],
Expand Down Expand Up @@ -1574,12 +1632,17 @@ def ProakisChannel(variant, sps=1):

class PDLElement(torch.nn.Module):
"""
Simulate PDL in optical channels
Simulate PDL in optical channels.
This class applies PDL which is randomly rotated w.r.t to
the input signal.
"""

def __init__(self, rho):
"""
Construct a PDL element which exhibits a differential linear
Construct a PDL element.
It exhibits a differential linear
attenuation of rho. To get the attenuation matrix Gamma we
first set attenuation of one polarization to sqrt(1+rho)
and the other to sqrt(1-rho) and then rotate
Expand All @@ -1598,6 +1661,7 @@ def __init__(self, rho):
self.Gamma = rot @ Gamma

def forward(self, signal):
"""Apply time-invariant PDL by mutiplying input signal and Jones matrix."""
return torch.matmul(self.Gamma, signal)


Expand All @@ -1610,6 +1674,16 @@ class DPImpairments(torch.nn.Module):
"""

def __init__(self, samp_rate, tau_cd, tau_pmd, phi_IQ, theta, rho=0):
r"""
Initialize py:class:`DPImpairments`.
:param samp_rate: sampling rate of the input signal
:param tau_cd: Residual chromatic dispersion coefficient \tau_{cd}
:param tau_pmd: Residual polarization-mode dispersion coefficient \tau_{pmd}
:param phi_IQ: Phase rotation \phi_{IQ}
:param theta: Polarization angle \theta
:param rho: Polarization angle \rho
"""
super(DPImpairments, self).__init__()
self.samp_rate = torch.as_tensor(samp_rate)
self.tau_cd = torch.as_tensor(tau_cd)
Expand All @@ -1621,7 +1695,7 @@ def __init__(self, samp_rate, tau_cd, tau_pmd, phi_IQ, theta, rho=0):

def forward(self, signal):
"""
Apply DPImpairment to a dual polarization signal
Apply DPImpairment to a dual polarization signal.
:param signal: Must be a 2xN complex-valued PyTorch tensor
"""
Expand Down Expand Up @@ -1677,9 +1751,7 @@ def forward(self, signal):


class PMDPDLChannel(torch.nn.Module):
"""
Optical channel with only PMD and PDL impairments.
"""
"""Optical channel with only PMD and PDL impairments."""

def __init__(
self,
Expand All @@ -1694,6 +1766,20 @@ def __init__(
pdl_min=0.1,
method="freq",
):
"""
Initialize :py:class:`PMDPDLChannel`.
:param L:
:param num_steps:
:param pmd_parameter:
:param pmd_correlation_length:
:param f_samp:
:param pmd_sigma:
:param num_pdl_elements:
:param pdl_max:
:param pdl_min:
:param method:
"""
super(PMDPDLChannel, self).__init__()
self.dz = torch.as_tensor(L / num_steps)
self.dt = 1.0 / f_samp
Expand All @@ -1720,10 +1806,19 @@ def __init__(
self.method = method

def step(self):
"""Propagate the PMD Elements in time.
Only relevant for time-variant PMD.
"""
for pe in self.pmd_elements:
_ = pe.step()

def forward(self, u):
"""
Apply Channel to input signal.
:param u: Dual-polarization complex input signal
"""
if self.method == "freq":
return self.forward_freq(u)
elif self.method == "time":
Expand All @@ -1732,6 +1827,11 @@ def forward(self, u):
raise ValueError("self.method must be either freq or time")

def forward_time(self, u):
"""
Apply channel in time domain.
:param u: Dual-polarization complex input signal
"""
w = (2 * torch.pi * torch.fft.fftfreq(u.shape[1], self.dt * 1e12)).to(
self.betapa.device
) # THz
Expand Down Expand Up @@ -1765,13 +1865,23 @@ def forward_time(self, u):
return u

def forward_freq(self, u):
"""
Apply channel in frequency domain.
:param u: Dual-polarization complex input signal
"""
# We perform the full simulation in the frequency domain
u_f = torch.fft.fft(u, dim=1)
u_f = self._forward_freq(u_f)
u = torch.fft.ifft(u_f, dim=1)
return u

def _forward_freq(self, u_f):
"""
Apply the channel on the signal already in f-domain.
:param u_f: Dual-polarization complex input signal in frequency domain.
"""
w = (2 * torch.pi * torch.fft.fftfreq(u_f.shape[1], self.dt * 1e12)).to(
self.betapa.device
) # THz
Expand Down Expand Up @@ -1803,6 +1913,12 @@ def _forward_freq(self, u_f):
return u_f

def channel_transfer(self, length, pulse_shape=None):
"""
Compute the 2x2 channel transfer function.
:param length: One-sided length of the desired impulse response
:param pulse_shape: Shaping to apply to the window function
"""
phase_shift = 2 * torch.pi * torch.fft.fftfreq(length, 1) * (length // 2 + 1)
u_f = torch.zeros((2, length), dtype=torch.complex64)
u_f[0, :] = torch.ones(length, dtype=torch.complex64) * torch.exp(
Expand Down
11 changes: 8 additions & 3 deletions src/mokka/equalizers/adaptive/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def reset(self):

def forward(self, y):
"""
Equalize input signal y
Equalize input signal y.
:param y: Complex receive signal y
"""
Expand Down Expand Up @@ -818,7 +818,9 @@ def forward(self, y):

class PilotAEQ_SP(torch.nn.Module):
"""
Perform pilot-based adaptive equalization (QPSK)
Perform pilot-based adaptive equalization (QPSK).
This class performs the adaptive equalization for a single polarization.
"""

def __init__(
Expand Down Expand Up @@ -853,7 +855,7 @@ def __init__(
self.method = method

def reset(self):
"""Reset :py:class:`PilotAEQ_SP`"""
"""Reset :py:class:`PilotAEQ_SP`."""
self.taps.zero_()
self.taps[self.taps.size()[0] // 2] = 1.0

Expand Down Expand Up @@ -956,10 +958,12 @@ def __init__(
# Do some clever initalization, first only equalize x-pol and then enable y-pol

def reset(self):
"""Reset :py:class:`AEQ_SP`."""
self.taps.zero_()
self.taps[self.taps.shape[0] // 2] = 1.0

def forward(self, y):
"""Perform adaptive equalization."""
# Implement CMA "by hand"
# Basically step through the signal advancing always +sps symbols
# and filtering 2*filter_len samples which will give one output sample with
Expand Down Expand Up @@ -995,4 +999,5 @@ def forward(self, y):
return out

def get_error_signal(self):
"""Extract the error signal."""
return self.out_e
3 changes: 2 additions & 1 deletion src/mokka/equalizers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def forward(self, y, center_freq=0):
class LinearFilter(torch.nn.Module):
"""Class implementing a SISO linear filter.
Optionally with trainable filter_taps."""
Optionally with trainable filter_taps.
"""

def __init__(
self,
Expand Down
Loading

0 comments on commit 40426a6

Please sign in to comment.