diff --git a/metro/modeling/_smpl.py b/metro/modeling/_smpl.py index 92c90ca..bbfed17 100644 --- a/metro/modeling/_smpl.py +++ b/metro/modeling/_smpl.py @@ -132,7 +132,7 @@ def get_h36m_joints(self, vertices): Input: vertices: size = (B, 6890, 3) Output: - 3D joints: size = (B, 24, 3) + 3D joints: size = (B, 17, 3) """ joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct]) return joints