Skip to content

Commit

Permalink
0.68.2
Browse files Browse the repository at this point in the history
  • Loading branch information
FBurkhardt committed Nov 9, 2023
1 parent fa9fe66 commit dbb1ea1
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

Version 0.68.2
--------------
* column names in datasets are now configurable

Version 0.68.1
--------------
* added error message on file to praat extraction
Expand Down
2 changes: 1 addition & 1 deletion nkululeko/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION="0.68.1"
VERSION="0.68.2"
SAMPLING_RATE = 16000
10 changes: 8 additions & 2 deletions nkululeko/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def _load_db(self):
self.util.error(f"{self.name}: no database found at {root}")
return root

def _check_cols(self, df):
    """Rename the columns of ``df`` according to this dataset's ``colnames`` option.

    The option is looked up via
    ``self.util.config_val_data(self.name, "colnames", False)``. When it is
    truthy it is parsed with ``ast.literal_eval`` and handed to
    ``df.rename(columns=...)``; otherwise ``df`` is returned as it came in.

    Args:
        df: the dataframe whose columns may be renamed.

    Returns:
        The renamed dataframe, or ``df`` itself if nothing is configured.
    """
    colnames_spec = self.util.config_val_data(self.name, "colnames", False)
    if not colnames_spec:
        # nothing configured for this dataset: hand the frame back untouched
        return df
    # the option arrives as text and is evaluated as a Python literal
    mapping = ast.literal_eval(colnames_spec)
    return df.rename(columns=mapping)

def _report_load(self):
speaker_num = 0
if self.got_speaker:
Expand Down Expand Up @@ -205,6 +212,7 @@ def _get_df_for_lists(self, db, df_files):
df = pd.DataFrame()
for table in df_files:
source_df = db.tables[table].df
source_df = self._check_cols(source_df)
# create a dataframe with the index (the filenames)
df_local = pd.DataFrame(index=source_df.index)
# try to get the targets from this dataframe
Expand All @@ -224,8 +232,6 @@ def _get_df_for_lists(self, db, df_files):
# try to get the gender values
if "gender" in source_df:
df_local["gender"] = source_df["gender"]
else:
df_local["gender"] = source_df["sex"]
got_gender = True
except (KeyError, ValueError, audformat.errors.BadKeyError) as e:
pass
Expand Down
13 changes: 7 additions & 6 deletions nkululeko/data/dataset_csv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# dataset_csv.py
import ast
import os
import os.path
import pandas as pd
Expand All @@ -22,6 +23,10 @@ def load(self):
root = os.path.dirname(data_file)
audio_path = self.util.config_val_data(self.name, "audio_path", "")
df = audformat.utils.read_csv(data_file)
rename_cols = self.util.config_val_data(self.name, "colnames", False)
if rename_cols:
col_dict = ast.literal_eval(rename_cols)
df = df.rename(columns=col_dict)
absolute_path = eval(
self.util.config_val_data(self.name, "absolute_path", True)
)
Expand All @@ -42,20 +47,16 @@ def load(self):
lambda x: root + "/" + audio_path + "/" + x
)
)

self.df = df
self.db = None
self.got_target = True
self.is_labeled = self.got_target
self.start_fresh = eval(
self.util.config_val("DATA", "no_reuse", "False")
)
self.start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
if self.is_labeled and not "class_label" in self.df.columns:
self.df["class_label"] = self.df[self.target]
if "gender" in self.df.columns:
self.got_gender = True
elif "sex" in self.df.columns:
self.df = self.df.rename(columns={"sex": "gender"})
self.got_gender = True
if "age" in self.df.columns:
self.got_age = True
if "speaker" in self.df.columns:
Expand Down

0 comments on commit dbb1ea1

Please sign in to comment.