Skip to content

Commit

Permalink
Merged PR 25220: Add extra model information to model_info.py script
Browse files Browse the repository at this point in the history
Adding model shapes flag to model_info.py script: dtype and total number of model parameters.
Example: `python model_info.py -m ~/model.npz -mi`
  • Loading branch information
alexandremuzio authored and Roman Grundkiewicz committed Nov 30, 2022
1 parent b6581c4 commit b7205fc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions scripts/contrib/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,15 @@ def main():
else:
print(model[args.key])
else:
total_nb_of_parameters = 0
for key in model:
if args.matrix_shapes:
print(key, model[key].shape)
if not key == S2S_SPECIAL_NODE:
total_nb_of_parameters += np.prod(model[key].shape)
if args.matrix_info:
print(key, model[key].shape, model[key].dtype)
else:
print(key)
print('Total number of parameters:', total_nb_of_parameters)


def parse_args():
Expand All @@ -57,8 +61,8 @@ def parse_args():
help="print values from special:model.yml node")
parser.add_argument("-f", "--full-matrix", action="store_true",
help="force numpy to print full arrays for single key")
parser.add_argument("-ms", "--matrix-shapes", action="store_true",
help="print shapes of all arrays in the model")
parser.add_argument("-mi", "--matrix-info", action="store_true",
help="print full matrix info for all keys. Includes shape and dtype")
return parser.parse_args()


Expand Down

0 comments on commit b7205fc

Please sign in to comment.