Skip to content

Commit

Permalink
Merge pull request #5 from marcus-k/new-api
Browse files Browse the repository at this point in the history
Add new NWN class
  • Loading branch information
marcus-k authored Jul 11, 2024
2 parents ec04014 + b09e3c5 commit e882da9
Show file tree
Hide file tree
Showing 14 changed files with 899 additions and 270 deletions.
1 change: 1 addition & 0 deletions dev-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
dependencies:
- ipykernel
- joblib
- dill
- line_profiler
- matplotlib=3.8.4
- networkx=3.1
Expand Down
3 changes: 2 additions & 1 deletion randomnwn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
"""
from .version import __version__

from .nanowire_network import create_NWN

from .nanowires import (
create_NWN,
convert_NWN_to_MNR,
add_wires,
add_electrodes,
Expand Down
48 changes: 27 additions & 21 deletions randomnwn/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,24 @@
# Author: Marcus Kasdorf
# Date: July 28, 2021

from __future__ import annotations

import numpy as np
import networkx as nx
from typing import Callable, List, Union, Tuple
from numbers import Number

import numpy.typing as npt
from .typing import *
from typing import Callable, TYPE_CHECKING
if TYPE_CHECKING:
from .nanowire_network import NanowireNetwork

from .calculations import solve_network


def resist_func(
NWN: nx.Graph,
w: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
NWN: NanowireNetwork,
w: float | npt.NDArray
) -> float | npt.NDArray:
"""
The HP group's resistance function in nondimensionalized form.
Expand All @@ -48,18 +54,18 @@ def resist_func(

def _HP_model_no_decay(
t: float,
w: np.ndarray,
NWN: nx.Graph,
source_node: Union[Tuple, List[Tuple]],
drain_node: Union[Tuple, List[Tuple]],
w: npt.NDArray,
NWN: NanowireNetwork,
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
voltage_func: Callable,
edge_list: list,
start_nodes: list,
end_nodes: list,
window_func: Callable,
solver: str = "spsolve",
kwargs: dict = None
) -> np.ndarray:
) -> npt.NDArray:
"""
Derivative of the nondimensionalized state variables `w`.
Expand Down Expand Up @@ -97,18 +103,18 @@ def _HP_model_no_decay(

def _HP_model_decay(
t: float,
w: np.ndarray,
NWN: nx.Graph,
source_node: Union[Tuple, List[Tuple]],
drain_node: Union[Tuple, List[Tuple]],
w: npt.NDArray,
NWN: NanowireNetwork,
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
voltage_func: Callable,
edge_list: list,
start_nodes: list,
end_nodes: list,
window_func: Callable,
solver: str = "spsolve",
kwargs: dict = None
) -> np.ndarray:
) -> npt.NDArray:
"""
Derivative of the nondimensionalized state variables `w` with
decay value `tau`.
Expand Down Expand Up @@ -150,18 +156,18 @@ def _HP_model_decay(

def _HP_model_chen(
t: float,
y: np.ndarray,
NWN: nx.Graph,
source_node: Union[Tuple, List[Tuple]],
drain_node: Union[Tuple, List[Tuple]],
y: npt.NDArray,
NWN: NanowireNetwork,
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
voltage_func: Callable,
edge_list: list,
start_nodes: list,
end_nodes: list,
window_func: Callable,
solver: str = "spsolve",
kwargs: dict = None
) -> np.ndarray:
) -> npt.NDArray:
"""
Derivative of the nondimensionalized state variables `w`, `tau`, and
`epsilon`.
Expand Down Expand Up @@ -206,7 +212,7 @@ def _HP_model_chen(
return dydt


def set_chen_params(NWN: nx.Graph, sigma: Number, theta: Number, a: Number):
def set_chen_params(NWN: NanowireNetwork, sigma, theta, a):
NWN.graph["sigma"] = sigma
NWN.graph["theta"] = theta
NWN.graph["a"] = a
66 changes: 36 additions & 30 deletions randomnwn/calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@
# Author: Marcus Kasdorf
# Date: July 19, 2021

from __future__ import annotations

import numpy as np
import numpy.typing as npt
import scipy
import networkx as nx
from networkx.linalg import laplacian_matrix
from typing import List, Tuple, Set, Union

from .typing import *
from .nanowires import get_edge_indices
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .nanowire_network import NanowireNetwork


def get_connected_nodes(NWN: nx.Graph, connected: List[Tuple]) -> Set[Tuple]:
def get_connected_nodes(NWN: NanowireNetwork, connected: list[NWNNode]) -> set[NWNNode]:
"""
Returns a list of nodes which are connected to any of the given nodes.
Expand All @@ -28,10 +34,10 @@ def get_connected_nodes(NWN: nx.Graph, connected: List[Tuple]) -> Set[Tuple]:


def create_matrix(
NWN: nx.Graph,
NWN: NanowireNetwork,
value_type: str = "conductance",
source_nodes: List[Tuple] = None,
drain_nodes: List[Tuple] = None,
source_nodes: list[NWNNode] = None,
drain_nodes: list[NWNNode] = None,
ground_nodes: bool = False
) -> scipy.sparse.csr_matrix:
"""
Expand Down Expand Up @@ -132,13 +138,13 @@ def _solver(A, z, solver, **kwargs):


def _solve_voltage(
NWN: nx.Graph,
NWN: NanowireNetwork,
voltage: float,
source_nodes: List[Tuple],
drain_nodes: List[Tuple],
source_nodes: list[NWNNode],
drain_nodes: list[NWNNode],
solver: str,
**kwargs
) -> np.ndarray:
) -> npt.NDArray:
"""
Solve for voltages at all the nodes for a given supplied voltage.
Expand Down Expand Up @@ -166,13 +172,13 @@ def _solve_voltage(


def _solve_current(
NWN: nx.Graph,
NWN: NanowireNetwork,
current: float,
source_nodes: List[Tuple],
drain_nodes: List[Tuple],
source_nodes: list[NWNNode],
drain_nodes: list[NWNNode],
solver: str,
**kwargs
) -> np.ndarray:
) -> npt.NDArray:
"""
Solve for voltages at all the nodes for a given supplied current.
Expand All @@ -192,14 +198,14 @@ def _solve_current(


def solve_network(
NWN: nx.Graph,
source_node: Union[Tuple, List[Tuple]],
drain_node: Union[Tuple, List[Tuple]],
NWN: NanowireNetwork,
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
input: float,
type: str = "voltage",
solver: str = "spsolve",
**kwargs
) -> np.ndarray:
) -> npt.NDArray:
"""
Solve for the voltages of each node in a given NWN. Each drain node will
be grounded. If the type is "voltage", each source node will be at the
Expand Down Expand Up @@ -254,14 +260,14 @@ def solve_network(


def solve_drain_current(
NWN: nx.Graph,
source_node: Union[Tuple, List[Tuple]],
drain_node: Union[Tuple, List[Tuple]],
NWN: NanowireNetwork,
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
voltage: float,
scaled: bool = False,
solver: str = "spsolve",
**kwargs
) -> np.ndarray:
) -> npt.NDArray:
"""
Solve for the current through each drain node of a NWN.
Expand Down Expand Up @@ -323,14 +329,14 @@ def solve_drain_current(


def solve_nodal_current(
NWN: nx.Graph,
source_node: Union[Tuple, List[Tuple]],
drain_node: Union[Tuple, List[Tuple]],
NWN: NanowireNetwork,
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
voltage: float,
scaled: bool = False,
solver: str = "spsolve",
**kwargs
) -> np.ndarray:
) -> npt.NDArray:
"""
Solve for the current through each node of a NWN. It will appear that
no current is flowing through source (drain) nodes for positive (negative)
Expand Down Expand Up @@ -393,14 +399,14 @@ def solve_nodal_current(


def solve_edge_current(
NWN: nx.Graph,
source_node: Union[Tuple, List[Tuple]],
drain_node: Union[Tuple, List[Tuple]],
NWN: NanowireNetwork,
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
voltage: float,
scaled: bool = False,
solver: str = "spsolve",
**kwargs
) -> np.ndarray:
) -> npt.NDArray:
"""
Solve for the current through each node of a NWN. It will appear that
no current is flowing through source (drain) nodes for positive (negative)
Expand Down Expand Up @@ -456,7 +462,7 @@ def solve_edge_current(
return I


def scale_sol(NWN: nx.Graph, sol: np.ndarray):
def scale_sol(NWN: NanowireNetwork, sol: npt.NDArray) -> npt.NDArray:
"""
Scale the voltage and current solutions by their characteristic values.
Expand Down
38 changes: 21 additions & 17 deletions randomnwn/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
from scipy.integrate import solve_ivp

from numbers import Number
from typing import Callable, List, Union, Tuple, Iterable
from typing import Callable, Iterable
from scipy.integrate._ivp.ivp import OdeResult
import numpy.typing as npt

from .nanowire_network import NanowireNetwork
from .typing import *

from .nanowires import get_edge_indices
from .calculations import solve_drain_current, solve_network
Expand All @@ -25,17 +29,17 @@


def solve_evolution(
NWN: nx.Graph,
t_eval: np.ndarray,
source_node: Union[Tuple, List[Tuple]],
drain_node: Union[Tuple, List[Tuple]],
NWN: NanowireNetwork,
t_eval: npt.NDArray,
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
voltage_func: Callable,
window_func: Callable = None,
tol: float = 1e-12,
model: str = "default",
solver: str = "spsolve",
**kwargs
) -> Tuple[OdeResult, List[Tuple]]:
) -> tuple[OdeResult, list[NWNEdge]]:
"""
Solve for the state variables `w` of the junctions of the given nanowire
network at various points in time with an applied voltage.
Expand Down Expand Up @@ -79,7 +83,7 @@ def solve_evolution(
Returns
-------
sol : OdeResult
Output from `scipy.intergrate.solve_ivp`. See the SciPy documentation
Output from `scipy.integrate.solve_ivp`. See the SciPy documentation
for information on this output's formatting.
edge_list : list of tuples
Expand Down Expand Up @@ -144,7 +148,7 @@ def solve_evolution(
return sol, edge_list


def set_state_variables(NWN: nx.Graph, *args):
def set_state_variables(NWN: NanowireNetwork, *args):
"""
Sets the given nanowire network's state variable. Can be called in the
following ways:
Expand Down Expand Up @@ -288,16 +292,16 @@ def set_state_variables(NWN: nx.Graph, *args):


def get_evolution_current(
NWN: nx.Graph,
NWN: NanowireNetwork,
sol: OdeResult,
edge_list: List[Tuple],
source_node: Union[Tuple, List[Tuple]],
drain_node: Union[Tuple, List[Tuple]],
edge_list: list[NWNEdge],
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
voltage_func: Callable,
scaled: bool = False,
solver: str = "spsolve",
**kwargs
) -> np.ndarray:
) -> npt.NDArray:
"""
To be used in conjunction with `solve_evolution`. Takes the output from
`solve_evolution` and finds the current passing through each drain node
Expand Down Expand Up @@ -365,13 +369,13 @@ def get_evolution_current(
def get_evolution_node_voltages(
NWN: nx.Graph,
sol: OdeResult,
edge_list: list[tuple],
source_node: tuple | list[tuple],
drain_node: tuple | list[tuple],
edge_list: list[NWNEdge],
source_node: NWNNode | list[NWNNode],
drain_node: NWNNode | list[NWNNode],
voltage_func: Callable,
solver: str = "spsolve",
**kwargs
) -> np.ndarray:
) -> npt.NDArray:
"""
To be used in conjunction with `solve_evolution`. Takes the output from
`solve_evolution` and finds the voltage of all nodes in the network at each
Expand Down
Loading

0 comments on commit e882da9

Please sign in to comment.