Skip to content

Commit

Permalink
Add progress bar when feature computation takes >5s
Browse files Browse the repository at this point in the history
  • Loading branch information
yonromai committed Oct 10, 2023
1 parent 17e72af commit 6f83df7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
1 change: 1 addition & 0 deletions nxontology_ml/model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def predict(
take: int | None = 0,
) -> pd.DataFrame:
target_nodes: list[str] = list(get_disease_nodes(take=take, nxo=nxo))
assert len(target_nodes) > 0, "No disease node found"
target_features = feature_pipeline.transform(target_nodes)
target_labels = model.predict(target_features)
target_probas = model.predict_proba(target_features)
Expand Down
21 changes: 19 additions & 2 deletions nxontology_ml/sklearn_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from nxontology.node import NodeInfo
from pandas.core.dtypes.base import ExtensionDtype
from sklearn.base import TransformerMixin
from tqdm import tqdm


@dataclass
Expand Down Expand Up @@ -73,16 +74,32 @@ def transform(self, X: NodeFeatures, copy: bool | None = None) -> NodeFeatures:
return X
if self._num_features_fn:
assert self._num_features_names

vecs: list[np.array] = []
for node in tqdm(
X.nodes,
desc=f"{self.__class__.__name__}: Computing num features",
delay=5,
):
vecs.append(self._num_features_fn(node))

new_features = pd.DataFrame(
data=[self._num_features_fn(node) for node in X.nodes],
data=vecs,
columns=self._num_features_names,
dtype=self._num_feature_dtype,
)
X.num_features = pd.concat([X.num_features, new_features], axis=1)
if self._cat_features_fn:
assert self._cat_features_names
cat_vecs: list[np.array] = []
for node in tqdm(
X.nodes,
desc=f"{self.__class__.__name__}: Computing cat features",
delay=5,
):
cat_vecs.append(self._cat_features_fn(node))
new_features = pd.DataFrame(
data=[self._cat_features_fn(node) for node in X.nodes],
data=cat_vecs,
columns=self._cat_features_names,
)
X.cat_features = pd.concat([X.cat_features, new_features], axis=1)
Expand Down
11 changes: 10 additions & 1 deletion nxontology_ml/text_embeddings/text_embeddings_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sklearn.base import TransformerMixin
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from tqdm import tqdm

from nxontology_ml.sklearn_transformer import (
NodeFeatures,
Expand Down Expand Up @@ -77,7 +78,15 @@ def transform(self, X: NodeFeatures, copy: bool | None = None) -> NodeFeatures:
return X

def _nodes_to_vec(self, X: NodeFeatures) -> np.ndarray:
    """Embed every node in ``X.nodes`` and collect the results in one array.

    Each node is passed to ``self._embedding_model.embed_node``; the
    per-node results are gathered in iteration order and handed to
    ``np.array``.

    Args:
        X: Container whose ``nodes`` attribute is iterated.

    Returns:
        ``np.array`` applied to the list of per-node embeddings
        (presumably one row per node -- confirm against ``embed_node``).
    """
    # ``delay=5`` keeps tqdm silent unless the loop has been running for
    # five seconds, so quick calls produce no progress-bar output.
    nodes_with_progress = tqdm(
        X.nodes,
        desc="Fetching node embeddings",
        delay=5,
    )
    embedded_nodes: list[np.ndarray] = [
        self._embedding_model.embed_node(node) for node in nodes_with_progress
    ]
    return np.array(embedded_nodes)

@classmethod
def from_config(
Expand Down

0 comments on commit 6f83df7

Please sign in to comment.