Commit 6c5740f: better cache paths

konstantinjdobler committed Nov 6, 2023
1 parent bac24a6 commit 6c5740f
Showing 1 changed file with 12 additions and 6 deletions.
src/deepfocus/fasttext_embs.py
@@ -15,6 +15,10 @@
 CACHE_DIR = (Path(os.getenv("XDG_CACHE_HOME", "~/.cache")) / "deepfocus").expanduser().resolve()
 
 
+def sanitize_path(path: str):
+    return Path(path).as_posix().replace("/", "_")
+
+
 def train_fasttext(
     text_path: str,
     target_tokenizer: PreTrainedTokenizer,
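
Note: the new helper routes all separator handling through pathlib. A quick sketch of its behavior (the example paths are hypothetical):

    from pathlib import Path

    def sanitize_path(path: str):
        return Path(path).as_posix().replace("/", "_")

    # Path() normalizes the input and drops trailing separators,
    # so the old rstrip("/\\") chain below becomes unnecessary.
    print(sanitize_path("/data/corpus/"))    # _data_corpus
    print(sanitize_path("data/corpus.txt"))  # data_corpus.txt

One caveat: as_posix() converts backslashes to forward slashes only on Windows; on POSIX systems a backslash is an ordinary filename character and is kept as-is.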
@@ -43,7 +47,7 @@ def train_fasttext(
     target_tokenizer_hash = Hasher().hash(target_tokenizer)
     data_file_suffix = Path(text_path).suffix
 
-    text_path_sanitized = text_path.rstrip("/\\").replace("/", "_").replace("\\", "_")
+    text_path_sanitized = sanitize_path(text_path)
 
     cache_file = CACHE_DIR / "data" / f"{text_path_sanitized}_tokenized_{target_tokenizer_hash}{data_file_suffix}"
 
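The cache key pairs the sanitized path with a hash of the tokenizer, so the same corpus tokenized by different tokenizers gets distinct cache files. A hypothetical example (the checkpoint name is a placeholder):

    from datasets.fingerprint import Hasher
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
    tok_hash = Hasher().hash(tok)
    # With text_path="/data/corpus.txt", the tokenized-text cache lands at
    # ~/.cache/deepfocus/data/_data_corpus.txt_tokenized_<tok_hash>.txt
    print(tok_hash)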
@@ -73,9 +77,11 @@ def train_fasttext(
             temp_file.write(text + "\n")
         cache_file = temp_file.name
 
-    logger.info(f"Training fasttext model on {f'tokenized {text_path}' if cache_tokenized_text else cache_file} for {epochs} epochs with {dim=}...")
+    logger.info(
+        f"Training fasttext model on {f'tokenized {text_path}' if cache_tokenized_text else cache_file} for {epochs} epochs with {dim=}..."
+    )
     # We use CBOW instead of skipgram because CBOW is more closely aligned with Masked Language Modeling
-    # minCount to filter out spurious tokens that wil not get a good fasttext embedding
+    # minCount to filter out spurious tokens that will not get a good fasttext embedding
     return fasttext.train_unsupervised(
         str(cache_file),
         dim=dim,
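
The call this hunk touches trains unsupervised fasttext embeddings on the (possibly cached) tokenized text. A minimal standalone sketch: model="cbow" and minCount follow the comments above, while the input file name and the concrete hyperparameter values are illustrative:

    import fasttext

    # CBOW rather than skipgram: predicting a token from its context
    # is closer to Masked Language Modeling than the reverse.
    # minCount drops rare tokens that would get poor embeddings.
    model = fasttext.train_unsupervised(
        "tokenized_corpus.txt",  # whitespace-separated tokens, one line per document
        model="cbow",
        dim=300,      # illustrative
        epoch=3,      # illustrative
        minCount=10,  # illustrative
    )
    model.save_model("fasttext_token_embs.bin")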
@@ -100,8 +106,8 @@ def download_pretrained_fasttext_word_embs(identifier: str, verbose=True):
     if os.path.exists(identifier):
         path = Path(identifier)
     else:
-        logger.debug(
-            f"Identifier '{identifier}' does not seem to be a path (file does not exist). Interpreting as language code."
+        logger.info(
+            f"Loading fasttext *word* embeddings for language '{identifier}' from https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{identifier}.300.bin.gz."
         )
 
         path = CACHE_DIR / "pretrained_fasttext" / f"cc.{identifier}.300.bin"
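
The new log message spells out where the pretrained vectors come from. For reference, a sketch of fetching and loading the same cc.<lang>.300 file with the fasttext package's own downloader (the language code "de" is just an example):

    import fasttext
    import fasttext.util

    # Downloads cc.de.300.bin.gz from dl.fbaipublicfiles.com into the
    # current directory, decompresses it, and skips if already present.
    fasttext.util.download_model("de", if_exists="ignore")
    word_embs = fasttext.load_model("cc.de.300.bin")
    print(word_embs.get_word_vector("Haus").shape)  # (300,)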
@@ -173,7 +179,7 @@ def train_or_load_fasttext_model(
 ):
     target_tokenizer_hash = Hasher().hash(target_tokenizer)
 
-    text_path_sanitized = text_path.rstrip("/\\").replace("/", "_").replace("\\", "_")
+    text_path_sanitized = sanitize_path(text_path)
 
     model_cache_path = Path(
         model_cache_path
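
The body of train_or_load_fasttext_model is not shown in this diff. A hypothetical sketch of the cache-then-train pattern its name suggests (the signature and keyword arguments are assumed, not taken from the source):

    import fasttext
    from pathlib import Path

    def train_or_load_fasttext_model_sketch(text_path, target_tokenizer, model_cache_path, epochs=3, dim=300):
        cache = Path(model_cache_path)
        if cache.exists():
            return fasttext.load_model(str(cache))
        model = train_fasttext(text_path, target_tokenizer, epochs=epochs, dim=dim)  # helper from this file
        cache.parent.mkdir(parents=True, exist_ok=True)
        model.save_model(str(cache))
        return model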
