Skip to content

Commit

Permalink
Fix bug in computation of stiffness matrix
Browse files Browse the repository at this point in the history
We need to scale the stiffness matrix with the (original) length of the segment as it represents an infitesimal quantity
  • Loading branch information
mstoelzle committed Nov 19, 2024
1 parent 375d5f7 commit fd05085
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 16 deletions.
15 changes: 7 additions & 8 deletions examples/simulate_planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,20 @@

# set parameters
rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
D = 1e-4 * jnp.diag(
jnp.repeat(
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
).flatten(),
)
params = {
"th0": jnp.array(0.0), # initial orientation angle [rad]
"l": 1e-1 * jnp.ones((num_segments,)),
"r": 2e-2 * jnp.ones((num_segments,)),
"rho": rho,
"g": jnp.array([0.0, 9.81]),
"E": 2e2 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
"G": 1e2 * jnp.ones((num_segments,)), # Shear modulus [Pa]
"D": D,
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
}
params["D"] = 1e-3 * jnp.diag(
jnp.repeat(
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
).flatten(),
) * params["l"]

# activate all strains (i.e. bending, shear, and axial)
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ name = "jsrm" # Required
#
# For a discussion on single-sourcing the version, see
# https://packaging.python.org/guides/single-sourcing-package-version/
version = "0.0.11" # Required
version = "0.0.12" # Required

# This is a one-line description or tagline of what your project does. This
# corresponds to the "Summary" metadata field:
Expand Down
2 changes: 1 addition & 1 deletion src/jsrm/symbolic_derivation/planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def symbolically_derive_planar_pcs_model(
kappa = xi[3 * i]
# shear strain
sigma_x = xi[3 * i + 1]
# elongation strain
# axial strain
sigma_y = xi[3 * i + 2]

# compute the cross-sectional area of the rod
Expand Down
6 changes: 4 additions & 2 deletions src/jsrm/systems/planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def select_params_for_lambdify_fn(params: Dict[str, Array]) -> List[Array]:
)

compute_stiffness_matrix_for_all_segments_fn = vmap(
compute_planar_stiffness_matrix, in_axes=(0, 0, 0, 0), out_axes=0
compute_planar_stiffness_matrix
)

@jit
Expand Down Expand Up @@ -214,14 +214,16 @@ def stiffness_fn(
Returns:
K: elastic matrix of shape (n_q, n_q) if formulate_in_strain_space is False or (n_xi, n_xi) otherwise
"""
# length of the segments
l = params["l"]
# cross-sectional area and second moment of area
A = jnp.pi * params["r"] ** 2
Ib = A**2 / (4 * jnp.pi)

# elastic and shear modulus
E, G = params["E"], params["G"]
# stiffness matrix of shape (num_segments, 3, 3)
S = compute_stiffness_matrix_for_all_segments_fn(A, Ib, E, G)
S = compute_stiffness_matrix_for_all_segments_fn(l, A, Ib, E, G)
# we define the elastic matrix of shape (n_xi, n_xi) as K(xi) = K @ xi where K is equal to
K = blk_diag(S)

Expand Down
5 changes: 3 additions & 2 deletions src/jsrm/systems/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ def compute_strain_basis(


@jit
def compute_planar_stiffness_matrix(A: Array, Ib: Array, E: Array, G: Array) -> Array:
def compute_planar_stiffness_matrix(l: Array, A: Array, Ib: Array, E: Array, G: Array) -> Array:
"""
Compute the stiffness matrix of the system.
Args:
l: length of the segment of shape ()
A: cross-sectional area of shape ()
Ib: second moment of area of shape ()
E: Elastic modulus of shape ()
Expand All @@ -120,6 +121,6 @@ def compute_planar_stiffness_matrix(A: Array, Ib: Array, E: Array, G: Array) ->
Returns:
S: stiffness matrix of shape (3, 3)
"""
S = jnp.diag(jnp.stack([Ib * E, 4 / 3 * A * G, A * E], axis=0))
S = l* jnp.diag(jnp.stack([Ib * E, 4 / 3 * A * G, A * E], axis=0))

return S
4 changes: 2 additions & 2 deletions tests/test_planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def test_planar_pcs_one_segment():
"r": jnp.array([2e-2]),
"rho": 1000 * jnp.ones((1,)),
"g": jnp.array([0.0, -9.81]),
"E": 1e7 * jnp.ones((1,)), # Elastic modulus [Pa]
"G": 1e6 * jnp.ones((1,)), # Shear modulus [Pa]
"E": 1e8 * jnp.ones((1,)), # Elastic modulus [Pa]
"G": 1e7 * jnp.ones((1,)), # Shear modulus [Pa]
}
# activate all strains (i.e. bending, shear, and axial)
strain_selector = jnp.ones((3,), dtype=bool)
Expand Down

0 comments on commit fd05085

Please sign in to comment.