
Commit a3d3828

Merge pull request #90 from AIRI-Institute/refactor_0811
logging fixes
Vitaly-Protasov authored Nov 8, 2022
2 parents 7bb35d5 + 16ff794 commit a3d3828
Showing 3 changed files with 20 additions and 12 deletions.
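In short, all three files switch from the standard-library logging module to the logging utilities bundled with transformers, and route their warnings through named loggers. A minimal sketch of the pattern the diffs below introduce (the logger names "probing" and "probing_parser" are taken from the diffs; the warning text is illustrative):

from transformers.utils import logging

# Set the transformers logging system to WARNING verbosity.
logging.set_verbosity_warning()

# Fetch a named logger; this commit uses "probing" and "probing_parser".
logger = logging.get_logger("probing")

# Former module-level logging.warning(...) calls go through the named logger.
logger.warning("example warning")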
probing/encoder.py: 7 changes (5 additions, 2 deletions)
@@ -1,5 +1,4 @@
 import gc
-import logging
 import typing
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -9,10 +8,14 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModel, AutoTokenizer
+from transformers.utils import logging
 
 from probing.data_former import EncodedVectorFormer, TokenizedVectorFormer
 from probing.utils import exclude_rows
 
+logging.set_verbosity_warning()
+logger = logging.get_logger("probing")
+
 
 class TransformersLoader:
     def __init__(
@@ -207,7 +210,7 @@ def get_tokenized_datasets(
         )
 
         if row_ids_to_exclude:
-            logging.warning(
+            logger.warning(
                 f"Since you decided not to truncate long sentences, {len(row_ids_to_exclude)} sample(s) were excluded"
             )
 
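Since transformers.utils.logging.get_logger returns an ordinary Python logger under the hood, callers can still tune these warnings with the stdlib API. A minimal sketch, assuming you wanted to silence the "probing" logger entirely (not part of this commit):

import logging

# get_logger("probing") is backed by logging.getLogger("probing"),
# so raising the level suppresses the warnings shown above.
logging.getLogger("probing").setLevel(logging.ERROR)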
probing/pipeline.py: 7 changes (5 additions, 2 deletions)
@@ -1,21 +1,25 @@
 import gc
 import os
 from time import time
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
 from tqdm import trange
 from transformers import get_linear_schedule_with_warmup
+from transformers.utils import logging
 
 from probing.classifier import MLP, LogReg
 from probing.data_former import TextFormer
 from probing.encoder import TransformersLoader
 from probing.metric import Metric
 from probing.utils import ProbingLog, lang_category_extraction, save_log
 
+logging.set_verbosity_warning()
+logger = logging.get_logger("probing")
+
 
 class ProbingPipeline:
     def __init__(
@@ -152,7 +156,6 @@ def run(
         self.log_info["params"]["original_classes_ratio"] = task_data.ratio_by_classes
 
         if verbose:
-            print("=" * 100)
             print(
                 f"Task in progress: {probe_task}\nPath to data: {task_data.data_path}"
             )
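Note that encoder.py and pipeline.py both ask for the name "probing". Python's logging machinery hands out one logger object per name, so the two modules share a single logger; a quick check (not part of the commit):

from transformers.utils import logging

# One logger object per name: both modules write to the same "probing" logger.
assert logging.get_logger("probing") is logging.get_logger("probing")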
probing/ud_parser/ud_parser.py: 18 changes (10 additions, 8 deletions)
@@ -1,5 +1,4 @@
 import csv
-import logging
 import os
 import re
 import typing
@@ -12,13 +11,17 @@
 from conllu.models import Token, TokenTree
 from nltk.tokenize import wordpunct_tokenize
 from sklearn.model_selection import train_test_split
+from transformers.utils import logging
 
 from probing.ud_parser.ud_config import (
     partitions_by_files,
     splits_by_files,
     too_much_files_err_str,
 )
 
+logging.set_verbosity_warning()
+logger = logging.get_logger("probing_parser")
+
 
 class ConlluUDParser:
     def __init__(self, shuffle: bool = True, verbose: bool = True):
@@ -115,11 +118,11 @@ def check_parts(self, parts: Dict, category: str) -> None:
         val_categories_set = set(parts["va"][1])
         te_categories_set = set(parts["te"][1])
         if tr_categories_set != val_categories_set:
-            logging.warning(
+            logger.warning(
                 f'The classes in train and validation parts are different for category "{category}"'
             )
         elif val_categories_set != te_categories_set:
-            logging.warning(
+            logger.warning(
                 f'The classes in train and test parts are different for category "{category}"'
             )
 
@@ -202,10 +205,10 @@ def generate_probing_file(
         num_classes = len(classified_sentences.keys())
 
         if num_classes == 1:
-            logging.warning(f'Category "{category}" has only one class')
+            logger.warning(f'Category "{category}" has only one class')
             return {}
         elif num_classes == 0:
-            logging.warning(
+            logger.warning(
                 f'This file does not contain examples of category "{category}"'
             )
             return {}
Expand All @@ -231,7 +234,7 @@ def generate_probing_file(
parts = {}

if not parts:
logging.warning(
logger.warning(
f'Not enough data of category "{category}" for stratified split'
)
return parts
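For context on the warning above: the imports in this file show the split is backed by sklearn's train_test_split, which refuses to stratify when any class has fewer than two members. A minimal sketch of that failure mode, with hypothetical data (not part of the commit):

from sklearn.model_selection import train_test_split

texts = ["a", "b", "c"]
labels = ["Sing", "Sing", "Plur"]  # "Plur" has only one member

try:
    # Stratified splitting needs at least two samples per class.
    train_test_split(texts, labels, stratify=labels, train_size=0.8)
except ValueError as err:
    print(err)  # the least populated class has only 1 member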
@@ -306,7 +309,7 @@ def generate_data_by_categories(
 
         if len(categories) == 0:
             paths_str = "\n".join([str(p) for p in paths])
-            logging.warning(
+            logger.warning(
                 f"Something went wrong during processing files. None categories were found for paths:\n{paths_str}"
             )
 
@@ -384,7 +387,6 @@ def convert(
             dir_path: a path to a directory with all files
         """
         if self.verbose:
-            print("=" * 100)
             paths_str = "\n".join(
                 [
                     str(p)