diff --git a/cij/misc/evec_disp2eig.py b/cij/misc/evec_disp2eig.py index 4b0f0a2..361500d 100644 --- a/cij/misc/evec_disp2eig.py +++ b/cij/misc/evec_disp2eig.py @@ -17,13 +17,10 @@ def evec_disp2eig(a: numpy.ndarray, mass: list) -> numpy.ndarray: This function returns normalized :math:`\\left|C_i^m\\right>`. - :param a: 3N x 1 displacement vector or 3N x 3N displacement vector matrix + :param a: M x 3N displacement vector matrix :param mass: N x 1 atom mass vector, of any unit - :returns: Eigenvector matrix, shape determined by the shape of ``a``: - - * If ``a`` is a displacement vector, return a 3N x 1 eigenvector - * If ``a`` is a displacement matrix, return a 3N x 3N eigenvector matrix + :returns: Eigenvector matrix, shape is same as ``a`` (M x 3N): ''' N = len(mass) # number of atoms @@ -31,18 +28,7 @@ def evec_disp2eig(a: numpy.ndarray, mass: list) -> numpy.ndarray: m = numpy.repeat(mass, 3) # Account for 3 polarization directions, N -> 3N a = numpy.copy(a) - if a.shape == (3*N,): - - # Multiply by mass matrix - - a *= numpy.sqrt(m) - - # Renormalization - - norm = numpy.conj(a) @ a.T - a /= numpy.sqrt(norm) - - elif a.shape == (3*N, 3*N): + if a.shape[1] == 3*N: # Multiply by mass matrix diff --git a/tests/test_cij_misc_evec_load.py b/tests/test_cij_misc_evec_load.py index 27e6ea8..451e36a 100644 --- a/tests/test_cij_misc_evec_load.py +++ b/tests/test_cij_misc_evec_load.py @@ -52,7 +52,7 @@ def mass(): return mass @pytest.mark.parametrize("nq, nbnd", [(2, 60)]) -def test_evec_dispvec_disp2eig_vector_version(nq, nbnd, mass): +def test_evec_dispvec_disp2eig_3Nx3N(nq, nbnd, mass): vecs = evec_load(Path(__file__).parent / "data" / test_files["vec"], nq, nbnd) eigs = evec_load(Path(__file__).parent / "data" / test_files["eig"], nq, nbnd) @@ -70,7 +70,7 @@ def test_evec_dispvec_disp2eig_vector_version(nq, nbnd, mass): @pytest.mark.parametrize("nq, nbnd", [(2, 60)]) -def test_evec_dispvec_disp2eig_matrix_version(nq, nbnd, mass): +def test_evec_dispvec_disp2eig_1x3N(nq, nbnd, mass): vecs = evec_load(Path(__file__).parent / "data" / test_files["vec"], nq, nbnd) eigs = evec_load(Path(__file__).parent / "data" / test_files["eig"], nq, nbnd) @@ -85,8 +85,8 @@ def test_evec_dispvec_disp2eig_matrix_version(nq, nbnd, mass): for ibnd in range(nbnd): - a1 = numpy.array(vecs[iq][1][ibnd][1]) - a2 = numpy.array(eigs[iq][1][ibnd][1]) + a1 = numpy.array([vecs[iq][1][ibnd][1]]) + a2 = numpy.array([eigs[iq][1][ibnd][1]]) a1 = evec_disp2eig(a1, mass) - assert numpy.allclose(a1, a2, atol=1e-4) + assert numpy.allclose(a1[0], a2, atol=1e-4)