diff --git a/app/app.py b/app/app.py index 0fc01d9..738564f 100644 --- a/app/app.py +++ b/app/app.py @@ -119,13 +119,13 @@ def predict_bow(): preds = [] data = request.get_json(force=True) texts = data['text'] - print(texts) + # print(texts) preprocessed_text = [preprocess(text, n=2) for text in texts.split('.')] texts_joined = [' '.join(text) for text in preprocessed_text] - print(texts_joined) + # print(texts_joined) vectorized_text = vectorizer.transform(texts_joined) preds = bow_model.predict(vectorized_text) - print(preds) + # print(preds) return jsonify(prediction=preds.tolist(),text=texts.split('.')) return None diff --git a/docs/_posts/0000-01-01-intro.md b/docs/_posts/0000-01-01-intro.md index 5b0f580..5f13ff4 100644 --- a/docs/_posts/0000-01-01-intro.md +++ b/docs/_posts/0000-01-01-intro.md @@ -1,6 +1,8 @@ --- layout: slide -title: "NLP Project" +title: "Using Natural Language Processing to Identify Unfair Clauses in Terms and Conditions Documents" --- -Use the right arrow to begin! +**Authors:** Jonathan Sears, Nick Radwin +**Institution:** Tulane University +**Emails:** jsears1@tulane.edu, nradwin@tulane.edu diff --git a/docs/_posts/0000-01-02-overview.md b/docs/_posts/0000-01-02-overview.md index 88478db..523c3e9 100644 --- a/docs/_posts/0000-01-02-overview.md +++ b/docs/_posts/0000-01-02-overview.md @@ -1,19 +1,9 @@ --- layout: slide -title: "Equations and Tables" +title: "Introduction" --- -Here is an inline equation: $\sum_{i=1}^n i = ?$ +## Introduction -And a block one: - -$$e = mc^2$$ - - -Here is a table: - -| header 1 | header 2 | -|----------|----------| -| value 1 | value 2 | -| value 3 | value 4 | +Despite their ubiquity, terms and conditions are seldom read by users, leading to widespread ignorance about potentially exploitative or unfair clauses. Our project aims to bring these hidden clauses to light by using a sentence-level text classifier that labels clauses as either exploitative (1) or non-exploitative (0). We based these labels on the categories outlined in a prior paper, which we discuss shortly. diff --git a/docs/_posts/0000-01-03-next.md b/docs/_posts/0000-01-03-next.md index a10b1ad..37394a9 100644 --- a/docs/_posts/0000-01-03-next.md +++ b/docs/_posts/0000-01-03-next.md @@ -1,13 +1,9 @@ --- layout: slide -title: "Images" +title: "Related Work" --- +Our experiments are primarily based on **CLAUDETTE**, a research project conducted at Stanford in 2018. +They ultimately used an ensemble method combining SVMs, LSTMs, and CNNs to achieve accuracy and F1-scores above 0.8. This was our target for this project. -Two ways to add an image. - -Note that the image is in the assets/img folder. - - - -![tulane](assets/img/tulane.png) +![claudette](assets/img/claudette.png) diff --git a/docs/_posts/0000-01-04-approach.md b/docs/_posts/0000-01-04-approach.md new file mode 100644 index 0000000..d7262fa --- /dev/null +++ b/docs/_posts/0000-01-04-approach.md @@ -0,0 +1,12 @@ +--- +layout: slide +title: "Approach" +--- + +We employed multiple machine learning approaches to address the challenge of identifying unfair clauses: +- **BERT models:** Utilized for their deep contextual representations. +- **Bag of Words (BoW):** Simplified text representation focusing on term frequencies. +- **Support Vector Machine (SVM):** Tested for its capability to establish a clear decision boundary. +- **Convolutional Neural Network (CNN):** Explored for its pattern recognition capabilities within text data.
+- **Gradient Boosting Machine (GBM):** Chosen for its robustness and iterative improvement on classification tasks. +- **Hybrid BERT/BoW model:** An attempt to combine the strengths of BERT and BoW models. diff --git a/docs/_posts/0000-01-04-conclusion.md b/docs/_posts/0000-01-04-conclusion.md deleted file mode 100644 index 76f0dc2..0000000 --- a/docs/_posts/0000-01-04-conclusion.md +++ /dev/null @@ -1,6 +0,0 @@ ---- -layout: slide -title: "Conclusions" ---- - -Hi there diff --git a/docs/_posts/0000-01-05-dataset-and-methodology.md b/docs/_posts/0000-01-05-dataset-and-methodology.md new file mode 100644 index 0000000..2f56da5 --- /dev/null +++ b/docs/_posts/0000-01-05-dataset-and-methodology.md @@ -0,0 +1,7 @@ +--- +layout: slide +title: "Dataset and Metrics" +--- +- **Dataset:** Consisted of 100 labeled terms and conditions documents, with each sentence categorized as fair or as one of nine subcategories of unfair. +- **Binary Classification:** Simplified from multiple classes to two (fair and unfair) to address the dataset's class imbalance (roughly 92% of sentences are fair). +- **Evaluation Metrics:** Precision, recall, and F1 score, with models trained on an evenly distributed sample of the two classes for fair performance evaluation. diff --git a/docs/_posts/0000-01-06-Experiments.md b/docs/_posts/0000-01-06-Experiments.md new file mode 100644 index 0000000..d70b1e6 --- /dev/null +++ b/docs/_posts/0000-01-06-Experiments.md @@ -0,0 +1,6 @@ +--- +layout: slide +title: "Experiments" +--- +We originally experimented with the more complex BERT representation of the text. The thinking behind this was that the BERT encodings would capture a better understanding of the text, both semantically and contextually. We experimented with many different methods of fine-tuning BERT, attempting to fine-tune a single classifier layer on top of the pooled output. +However, we were unable to produce results near those of CLAUDETTE, with our best variants of the fine-tuned BERT model unable to crack an F1-score of 0.6 \ No newline at end of file diff --git a/docs/_posts/0000-01-07-Experiments-hybrid.md b/docs/_posts/0000-01-07-Experiments-hybrid.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/_posts/0000-01-08-Experiments-bow.md b/docs/_posts/0000-01-08-Experiments-bow.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/assets/img/claudette.png b/docs/assets/img/claudette.png new file mode 100644 index 0000000..5d0acb7 Binary files /dev/null and b/docs/assets/img/claudette.png differ diff --git a/notebooks/Experiments.ipynb b/notebooks/Experiments.ipynb index eb415a9..8eeafcc 100644 --- a/notebooks/Experiments.ipynb +++ b/notebooks/Experiments.ipynb @@ -11,9 +11,19 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-30 15:09:53.241092: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-04-30 15:09:56.280105: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + ] + } + ], "source": [ "#imports\n", "import pandas as pd\n", @@ -50,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -233,7 +243,7 @@ "4 NaN NaN NaN NaN "
] }, - "execution_count": 28, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -259,7 +269,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -279,7 +289,7 @@ "Name: proportion, dtype: float64" ] }, - "execution_count": 29, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -293,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -311,7 +321,7 @@ " 9: 'USE'}" ] }, - "execution_count": 30, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -349,7 +359,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -419,7 +429,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -624,7 +634,7 @@ "[5 rows x 21 columns]" ] }, - "execution_count": 32, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -645,7 +655,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -654,7 +664,7 @@ "'ubisoft advis includ surnam user name ubisoft_advis advis_includ includ_surnam surnam_user user_name'" ] }, - "execution_count": 33, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -665,7 +675,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -674,7 +684,7 @@ "'websites & communications terms of use'" ] }, - "execution_count": 34, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -685,7 +695,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -702,7 +712,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -711,7 +721,7 @@ "array(['0', '0', '0', ..., '1', '0', '0'], dtype=object)" ] }, - "execution_count": 36, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -722,7 +732,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -731,7 +741,7 @@ "array(['0'], dtype=object)" ] }, - "execution_count": 37, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -749,7 +759,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -762,7 +772,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -836,7 +846,7 @@ "1090 regardless manner arbitr conduct arbitr shall ... 
" ] }, - "execution_count": 39, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -847,7 +857,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -861,7 +871,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -885,7 +895,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -906,13 +916,13 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "543730edf3b548eda90e9d07ebaa5e82", + "model_id": "0f573db3b9eb40a9a68c3a03eab3a65a", "version_major": 2, "version_minor": 0 }, @@ -926,7 +936,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fe45c9f6551048be9f84beacc2af8e5c", + "model_id": "28b03324a3294e33820e35dc4740b44a", "version_major": 2, "version_minor": 0 }, @@ -940,7 +950,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d8288f6a08124db9926f3c667b80f597", + "model_id": "876daca7a0df4b448178f0f46dd3f044", "version_major": 2, "version_minor": 0 }, @@ -954,7 +964,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "dfe5241ef76748f8b12e815ef2e24df7", + "model_id": "8497c2f42ecd4cb9bf495c98dbd2519e", "version_major": 2, "version_minor": 0 }, @@ -968,7 +978,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e11bdc14ae9a48138847b00602395045", + "model_id": "58cb8fe3c4e04ed3b9127cc7164d0a6e", "version_major": 2, "version_minor": 0 }, @@ -994,7 +1004,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -1021,7 +1031,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -1053,7 +1063,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -1062,15 +1072,15 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "they also do not apply to membership of skype developer .\n", - "also appli membership skype develop also_appli appli_membership membership_skype skype_develop\n" + "for any action at law or in equity relating to the arbitration provision of these terms of use , the excluded disputes or if you opt out of the agreement to arbitrate , you agree to resolve any dispute you have with instagram exclusively in a state or federal court located in santa clara , california , and to submit to the personal jurisdiction of the courts located in santa clara county for the purpose of litigating all such disputes .\n", + "action law equiti relat arbitr provis term use exclud disput opt agreement arbitr agre resolv disput instagram exclus state feder court locat santa clara california submit person jurisdict court locat santa clara counti purpos litig disput action_law law_equiti equiti_relat relat_arbitr arbitr_provis provis_term term_use use_exclud exclud_disput disput_opt opt_agreement agreement_arbitr arbitr_agre agre_resolv resolv_disput disput_instagram instagram_exclus exclus_state state_feder feder_court court_locat locat_santa santa_clara clara_california california_submit submit_person person_jurisdict 
jurisdict_court court_locat locat_santa santa_clara clara_counti counti_purpos purpos_litig litig_disput\n" ] } ], @@ -1090,7 +1100,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -1139,7 +1149,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -1182,7 +1192,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -1233,7 +1243,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -1265,7 +1275,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -1274,7 +1284,7 @@ "tensor([ 1, 0, 12, 4, 124, 44, 2])" ] }, - "execution_count": 52, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -1285,7 +1295,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -1340,7 +1350,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 68, "metadata": {}, "outputs": [ { @@ -1350,56 +1360,84 @@ "You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.\n", "Some weights of BertModel were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.query.bias', 
'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.output.LayerNorm.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.self.key.bias', 
'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.dense.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", - "100%|██████████| 300/300 
[01:08<00:00, 4.36it/s]\n" + "100%|██████████| 300/300 [00:55<00:00, 5.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Epoch: 1 | Accuracy: 0.44 | Precision: 0.48175182481751827 | Recall: 0.8354430379746836 | F1: 0.6111111111111112 | Loss: 0.8460731989145279\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100/100 [00:25<00:00, 3.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Epoch: 1 | Accuracy: 0.13 | Precision: 0.13131313131313133 | Recall: 0.9285714285714286 | F1: 0.23008849557522124 | Loss: 0.969728703200817\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 300/300 [01:54<00:00, 2.63it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Train Epoch: 1 | Accuracy: 0.5566666666666666 | Precision: 0.5566666666666666 | Recall: 1.0 | F1: 0.715203426124197 | Loss: 0.06296422884643423\n" + "Train Epoch: 2 | Accuracy: 0.4533333333333333 | Precision: 0.4892086330935252 | Recall: 0.8607594936708861 | F1: 0.6238532110091743 | Loss: 0.8457227943340937\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 100/100 [00:25<00:00, 3.97it/s]\n" + "100%|██████████| 100/100 [00:30<00:00, 3.28it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Test Epoch: 1 | Accuracy: 0.11 | Precision: 0.11 | Recall: 1.0 | F1: 0.1981981981981982 | Loss: 0.05178694828366907\n" + "Test Epoch: 2 | Accuracy: 0.13 | Precision: 0.13131313131313133 | Recall: 0.9285714285714286 | F1: 0.23008849557522124 | Loss: 0.969214705824852\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 300/300 [01:08<00:00, 4.39it/s]\n" + "100%|██████████| 300/300 [01:18<00:00, 3.80it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Train Epoch: 2 | Accuracy: 0.5566666666666666 | Precision: 0.5566666666666666 | Recall: 1.0 | F1: 0.715203426124197 | Loss: 0.06295500936018938\n" + "Train Epoch: 3 | Accuracy: 0.42 | Precision: 0.4701492537313433 | Recall: 0.7974683544303798 | F1: 0.5915492957746479 | Loss: 0.8475727983315786\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 100/100 [00:26<00:00, 3.79it/s]" + "100%|██████████| 100/100 [00:21<00:00, 4.75it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Test Epoch: 2 | Accuracy: 0.11 | Precision: 0.11 | Recall: 1.0 | F1: 0.1981981981981982 | Loss: 0.05175898351818761\n" + "Test Epoch: 3 | Accuracy: 0.13 | Precision: 0.13131313131313133 | Recall: 0.9285714285714286 | F1: 0.23008849557522124 | Loss: 0.968689352273941\n" ] }, { @@ -1415,13 +1453,13 @@ "hybrid_model.to(device)\n", "hybrid_model.train()\n", "loss_fn = nn.BCELoss()\n", - "optimizer = Adam(hybrid_model.parameters(),lr =.00001)\n", - "train_hybrid(hybrid_model,tokenized_dict[\"train_subset\"],tokenized_dict[\"test_subset\"],2,loss_fn,optimizer,3460,2042,bow,testing=True)" + "optimizer = Adam(hybrid_model.parameters(),lr =.000001)\n", + "train_hybrid(hybrid_model,tokenized_dict[\"train\"],tokenized_dict[\"test\"],5,loss_fn,optimizer,3460,2042,bow,testing=True)" ] }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 58, "metadata": {}, "outputs": [], "source": [