Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge NLP module fixes/improvements #120

Merged
merged 17 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
585 changes: 585 additions & 0 deletions nlp/desci_sense/evaluation/Evaluation_benchmark.py

Large diffs are not rendered by default.

69 changes: 69 additions & 0 deletions nlp/desci_sense/evaluation/item_type_stat_pie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from datetime import datetime
import wandb
from pathlib import Path
import pandas as pd
import numpy as np
import sys

sys.path.append(str(Path(__file__).parents[2]))

from desci_sense.evaluation.Evaluation_benchmark import TwitterEval
from desci_sense.evaluation.utils import obj_str_to_dict, get_dataset

if __name__ == "__main__":

wandb.login()

api = wandb.Api()

#TODO move from testing
run = wandb.init(project="testing", job_type="evaluation")

# get artifact path

dataset_artifact_id = (
'common-sense-makers/filter_evaluation/prediction_evaluation-20240521132713:v0'
)

# set artifact as input artifact
dataset_artifact = run.use_artifact(dataset_artifact_id)

# initialize table path
# add the option to call table_path = arguments.get('--dataset')

# download path to table
a_path = dataset_artifact.download()
print("The path is",a_path)

# get dataset file name

table_path = Path(f"{a_path}/prediction_evaluation.table.json")


# return the pd df from the table
#remember to remove the head TODO
df = get_dataset(table_path)

dataset_run = dataset_artifact.logged_by()

config = dataset_run.config

Eval = TwitterEval(config=config)



fig1, fig2 = Eval.build_item_type_pie(df=df)

wandb.log({"item_type_distribution": wandb.Image(fig1)})

wandb.log({"allowlist_item_type_distribution": wandb.Image(fig2)})

true_df = df[df["True Label"] == 'research']

fig1 , fig2 = Eval.build_item_type_pie(true_df)
wandb.log({"research_type_distribution": wandb.Image(fig1)})
config = obj_str_to_dict(config)

run.config.update(config)

wandb.run.finish()
20 changes: 19 additions & 1 deletion nlp/desci_sense/evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import pandas as pd
import numpy as np
from collections import Counter
import concurrent.futures
from tqdm import tqdm
from sklearn.preprocessing import MultiLabelBinarizer
Expand Down Expand Up @@ -132,4 +133,21 @@ def create_custom_confusion_matrix(y_true, y_pred, labels):
fp_j = ~y_true[:, j] & y_pred[:, j]
matrix[i, j] = np.sum(fn_i & fp_j)

return pd.DataFrame(matrix, index=labels, columns=labels)
return pd.DataFrame(matrix, index=labels, columns=labels)



def autopct_format(pct, total_counts):
    """Render a pie-slice percentage as ``'<pct>% (<count>)'``.

    Given the slice percentage supplied by matplotlib's ``autopct`` hook and
    the full list of counts, recover the absolute count behind the slice and
    show both together.
    """
    grand_total = sum(total_counts)
    absolute = int(round(grand_total * pct / 100.0))
    return f'{pct:.1f}% ({absolute})'

def projection_to_list(list2):
    """Build a filter that projects any list onto the members of *list2*.

    The returned callable keeps, in original order and including duplicates,
    only the elements of its argument that also occur in *list2* (unlike a
    set intersection, which would drop order and duplicates).
    """
    def project_to_list(list1):
        kept = []
        for element in list1:
            if element in list2:
                kept.append(element)
        return kept
    return project_to_list

def flatten_list(lis: list):
    """Concatenate the sub-lists of *lis* into one flat list (one level deep)."""
    flat = []
    for sublist in lis:
        flat.extend(sublist)
    return flat

Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ In this doc I specify what triplets are to be present in our app auto-published
sub:provenance {
    #Worked with Tobias on a more robust prov, TODO
cosmo: a prov:SoftwareAgent ;
rdfs:label "research_filter_v1" ;
prov:actedOnBehalfOf x:xHandle .
sub:activity a cosmo:nlpFacilitatedActivity ;
prov:wasAssociatedWith cosmo:.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


## Template Schema

```
sub:pubinfo {
  x:xHandle foaf:name "{retractor's name}" .
Expand All @@ -34,4 +34,5 @@ sub:pubinfo {
rdfs:label "CoSMO Semantic Post".
this: cosmo:hasRootSinger "{eth address}"
}
}
```
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import requests
from datetime import datetime


from ...interface import AppPost, PlatformType
from ...utils import (
extract_and_expand_urls,
normalize_url,
extract_twitter_status_id,
remove_dups_ordered,
normalize_tweet_url,
)
from ...schema.post import RefPost, QuoteRefPost

Expand Down Expand Up @@ -189,22 +191,6 @@ def extract_status_id(url):
return None


def normalize_tweet_url(url):
    """
    Normalize Twitter post URLs to use the x.com domain.

    Parameters:
        url (str): The original Twitter URL.

    Returns:
        str: The normalized URL with x.com domain.
    """
    # Non-Twitter URLs pass through untouched.
    if "twitter.com" not in url:
        return url
    return url.replace("twitter.com", "x.com")


# TODO combine with method below
def extract_external_ref_urls(tweet: dict, add_qrt_url: bool = True):
"""
Expand Down
99 changes: 47 additions & 52 deletions nlp/desci_sense/shared_functions/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rdflib import URIRef, Literal, Graph
from .prompting.jinja.topics_template import ALLOWED_TOPICS
from .filters import SciFilterClassfication
from .utils import normalize_tweet_urls_in_text, normalize_tweet_url

# for calculating thread length limits
MAX_CHARS_PER_POST = 280
Expand Down Expand Up @@ -91,6 +92,39 @@ class TopicConceptDefinition(OntologyConceptDefinition):
)


class ZoteroItemTypeDefinition(OntologyConceptDefinition):
    """
    Definition of the ZoteroItemType predicate which is used to represent a reference's
    item type according to the Zotero ontology.
    https://www.zotero.org/support/kb/item_types_and_fields
    """

    # Short concept name identifying this predicate.
    name: str = Field(default="zoteroItemType", description="Concept name.")
    # Linked-data URI for the predicate under the sense-nets namespace.
    uri: str = Field(
        default="https://sense-nets.xyz/hasZoteroItemType",
        description="Linked data URI for this concept.",
    )
    # Ontology versions that include this concept.
    versions: List[str] = Field(
        ["v0"], description="Which ontology versions is this item included in."
    )


class QuotedPostDefinition(OntologyConceptDefinition):
    """
    Definition of quotedPost relation for a post that quotes another post
    https://github.com/Common-SenseMakers/sensemakers/blob/nlp-dev/nlp/desci_sense/schema/Nanopub_schema/semantic_post_quote_schema.md
    """

    # Fixed copy-paste bug: default was "zoteroItemType" (copied from
    # ZoteroItemTypeDefinition) but this concept is the quotesPost relation,
    # matching the URI below.
    name: str = Field(default="quotesPost", description="Concept name.")
    # Linked-data URI for the quotesPost relation.
    uri: str = Field(
        default="https://sense-nets.xyz/quotesPost",
        description="Linked data URI for this concept.",
    )
    # Ontology versions that include this concept.
    versions: List[str] = Field(
        ["v0"], description="Which ontology versions is this item included in."
    )


class isAConceptDefintion(OntologyConceptDefinition):
name: str = Field(default="isA", description="Concept name.")
uri: str = Field(
Expand Down Expand Up @@ -141,24 +175,6 @@ class OntologyInterface(BaseModel):
ontology_config: NotionOntologyConfig = Field(default_factory=NotionOntologyConfig)


# TODO remove - changed to OntologyPredicateDefinition
class OntologyItem(TypedDict):
    """Legacy TypedDict describing one ontology concept.

    Marked for removal (see TODO above) — superseded by
    OntologyPredicateDefinition.
    """

    URI: str
    display_name: str
    Name: Optional[str]
    label: Optional[str]
    prompt: str
    notes: Optional[str]
    valid_subject_types: Optional[str]
    valid_object_types: Optional[str]
    versions: Optional[str]


# TODO remove - changed to KeywordPredicateDefinition
class KeywordsSupport(TypedDict):
    """Legacy TypedDict wrapping the keywords ontology item.

    Marked for removal (see TODO above) — superseded by
    KeywordPredicateDefinition.
    """

    keyWordsOntology: OntologyItem


class RefMetadata(BaseModel):
"""
Schema representing extracted metadata of reference URLs
Expand Down Expand Up @@ -248,7 +264,7 @@ def graph_serializer(graph: Graph):

@field_validator(
"semantics", mode="before"
) # before needed since arbitrary types allowec
) # before needed since arbitrary types allowed
@classmethod
def ensure_graph(cls, value: Any):
if isinstance(value, Graph):
Expand All @@ -271,17 +287,6 @@ def lower_case_platform_id(cls, v):
return v.lower() if isinstance(v, str) else v


# class AppPostContent(BaseModel):


# class AppPost(BaseModel):
# content: str = Field(description="Post content")
# url: Optional[str] = Field(description="Post url", default=None)
# quoted_thread_url: Optional[str] = Field(
# description="Url of quoted thread", default=None
# )


class AppPost(BaseModel):
content: str = Field(description="Post content")
url: Optional[str] = Field(description="Post url", default="")
Expand All @@ -290,6 +295,14 @@ class AppPost(BaseModel):
default=None,
)

@validator("content", pre=True, always=True)
def normalize_twitter_urls(cls, v):
return normalize_tweet_urls_in_text(v) if isinstance(v, str) else v

@validator("url", pre=True, always=True)
def normalize_twitter_url(cls, v):
return normalize_tweet_url(v) if isinstance(v, str) else v


class AppThread(BaseModel):
author: Author
Expand All @@ -299,6 +312,10 @@ class AppThread(BaseModel):
default=None,
)

@validator("url", pre=True, always=True)
def normalize_twitter_url(cls, v):
return normalize_tweet_url(v) if isinstance(v, str) else v

    @property
    def source_network(self) -> PlatformType:
        """Return the author's platformId as this thread's source network."""
        return self.author.platformId
Expand All @@ -314,25 +331,3 @@ class ParsePostRequest(BaseModel):
description="Additional params for parser (not used currently)",
default_factory=dict,
)


# TODO remove - changed to RefMetadata
class RefMeta(TypedDict):
    """Legacy TypedDict of extracted reference metadata.

    Marked for removal (see TODO above) — superseded by RefMetadata.
    """

    title: str
    description: str
    image: str


class ReflabelsSupport(TypedDict):
    """Bundle of the labels ontology plus per-URL reference metadata."""

    labelsOntology: List[OntologyItem]
    refsMeta: Dict[str, RefMeta]


class ParsedSupport(TypedDict):
    """Supporting data attached to a parser result."""

    keywords: KeywordsSupport
    refLabels: ReflabelsSupport


class ParserResultDto(TypedDict):
    """Parser result DTO: serialized semantics plus its support data."""

    semantics: str
    support: ParsedSupport
Loading