Skip to content

Commit

Permalink
Merge pull request #6 from Coda-Research-Group/bugfix-#5-create-buckets
Browse files Browse the repository at this point in the history
fix model params not parsed if model_path empty
  • Loading branch information
ProchazkaDavid authored Oct 2, 2024
2 parents 3598184 + 2f41978 commit 34c5756
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions training/create-buckets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from model import LIDatasetPredict, load_model
from tqdm import tqdm
from utils import (
create_dir,
dir_exists,
Expand Down Expand Up @@ -42,16 +41,23 @@ def load_all_embeddings(path):
def parse_model_params(model_path):
LOG.info(f'Parsing out model params from model path: {model_path}')
pattern = r'model-(\w+)--.*?n_classes-(\d+)(?:--.*?dimensionality-(\d+))?'

if model_path is None:
model = 'MLP'
dimensionality = DEFAULT_DIMENSIONALITY
n_classes = 2
LOG.info(f'Parsed out model={model}, dimensionality={dimensionality}, n_classes={n_classes}')
return model, dimensionality, n_classes

match = re.search(pattern, model_path, re.MULTILINE)
# new model format
if match and len(match.groups()) == 3:
model = match.group(1)
n_classes = int(match.group(2))
dimensionality = match.group(3)
model, n_classes, dimensionality = match.groups()
dimensionality = int(dimensionality) if dimensionality is not None else DEFAULT_DIMENSIONALITY
n_classes = int(n_classes)
else:
LOG.info(f'Failed to parse out model params from model path: {model_path}')
exit(1)

LOG.info(f'Parsed out model={model}, dimensionality={dimensionality}, n_classes={n_classes}')
return model, dimensionality, n_classes

Expand Down

0 comments on commit 34c5756

Please sign in to comment.