diff --git a/nlp/app/forms.py b/nlp/app/forms.py index 9276357..4dbe0c3 100644 --- a/nlp/app/forms.py +++ b/nlp/app/forms.py @@ -17,11 +17,57 @@ class Meta: # Ignoring CSRF security feature. ) text_options = SelectField(label='Select Option:', - choices=[ - ('option1', 'Option 1 Text'), - ('option2', 'Option 2 Text'), - ('option3', 'Option 3 Text'), - ], + choices= [ + ('./nlp/data/vectors/Confident_Insecure.pt', 'Insecure - Confident'), + ('./nlp/data/vectors/Evasive_Direct.pt', 'Direct - Evasive'), + ('./nlp/data/vectors/Diplomatic_Blunt.pt', 'Blunt - Diplomatic'), + ('./nlp/data/vectors/Inquisitive_Disinterested.pt', 'Disinterested - Inquisitive'), + ('./nlp/data/vectors/Thoughtful_Impulsive.pt', 'Impulsive - Thoughtful'), + ('./nlp/data/vectors/Nonchalant_Concerned.pt', 'Concerned - Nonchalant'), + ('./nlp/data/vectors/Flippant_Serious.pt', 'Serious - Flippant'), + ('./nlp/data/vectors/Precise_Vague.pt', 'Vague - Precise'), + ('./nlp/data/vectors/Rambling_Concise.pt', 'Concise - Rambling'), + ('./nlp/data/vectors/Analytical_Intuitive.pt', 'Intuitive - Analytical'), + ('./nlp/data/vectors/Assertive_Passive.pt', 'Passive - Assertive'), + ('./nlp/data/vectors/Considerate_Thoughtless.pt', 'Thoughtless - Considerate'), + ('./nlp/data/vectors/Elusive_Clear.pt', 'Clear - Elusive'), + ('./nlp/data/vectors/Candid_Guarded.pt', 'Guarded - Candid'), + ('./nlp/data/vectors/Defensive_Open.pt', 'Open - Defensive'), + ('./nlp/data/vectors/Engaging_Detached.pt', 'Detached - Engaging'), + ('./nlp/data/vectors/Reserved_Outgoing.pt', 'Outgoing - Reserved'), + ('./nlp/data/vectors/Empathetic_Unsympathetic.pt', 'Unsympathetic - Empathetic'), + ('./nlp/data/vectors/Concise_Lengthy.pt', 'Lengthy - Concise'), + ('./nlp/data/vectors/Enthusiastic_Apathetic.pt', 'Apathetic - Enthusiastic'), + ('./nlp/data/vectors/Introspective_Extrospective.pt', 'Extrospective - Introspective'), + ('./nlp/data/vectors/Polite_Rude.pt', 'Rude - Polite'), + ('./nlp/data/vectors/Indecisive_Decisive.pt', 'Decisive - Indecisive'), + ('./nlp/data/vectors/Dismissive_Receptive.pt', 'Receptive - Dismissive'), + ('./nlp/data/vectors/Deliberate_Hasty.pt', 'Hasty - Deliberate'), + ('./nlp/data/vectors/Informative_Misleading.pt', 'Misleading - Informative'), + ('./nlp/data/vectors/Focused_Distracted.pt', 'Distracted - Focused'), + ('./nlp/data/vectors/Perplexed_Clear.pt', 'Clear - Perplexed'), + ('./nlp/data/vectors/Cooperative_Uncooperative.pt', 'Uncooperative - Cooperative'), + ('./nlp/data/vectors/Inattentive_Attentive.pt', 'Attentive - Inattentive'), + ('./nlp/data/vectors/Contemplative_Shallow.pt', 'Shallow - Contemplative'), + ('./nlp/data/vectors/Evocative_Uninspiring.pt', 'Uninspiring - Evocative'), + ('./nlp/data/vectors/Witty_Dull.pt', 'Dull - Witty'), + ('./nlp/data/vectors/Succinct_Rambling.pt', 'Rambling - Succinct'), + ('./nlp/data/vectors/Arrogant_Humble.pt', 'Humble - Arrogant'), + ('./nlp/data/vectors/Measured_Impulsive.pt', 'Impulsive - Measured'), + ('./nlp/data/vectors/Elaborate_Simple.pt', 'Simple - Elaborate'), + ('./nlp/data/vectors/Unresponsive_Responsive.pt', 'Responsive - Unresponsive'), + ('./nlp/data/vectors/Courteous_Rude.pt', 'Rude - Courteous'), + ('./nlp/data/vectors/Tentative_Definite.pt', 'Definite - Tentative'), + ('./nlp/data/vectors/Compelling_Unconvincing.pt', 'Unconvincing - Compelling'), + ('./nlp/data/vectors/Casual_Formal.pt', 'Formal - Casual'), + ('./nlp/data/vectors/Insightful_Superficial.pt', 'Superficial - Insightful'), + ('./nlp/data/vectors/Assertive_Passive_2.pt', 'Passive - Assertive'), + ('./nlp/data/vectors/Outgoing_Introverted.pt', 'Introverted - Outgoing'), + ('./nlp/data/vectors/Concise_Wordy.pt', 'Wordy - Concise'), + ('./nlp/data/vectors/Confident_Timid.pt', 'Timid - Confident'), + ('./nlp/data/vectors/Polite_Impolite.pt', 'Impolite - Polite'), + ('./nlp/data/vectors/Engaging_Aloof.pt', 'Aloof - Engaging') +], validators=[DataRequired()], render_kw={'class': 'text-options-dropdown'} ) diff --git a/nlp/app/routes.py b/nlp/app/routes.py index c557b26..253406b 100644 --- a/nlp/app/routes.py +++ b/nlp/app/routes.py @@ -29,6 +29,8 @@ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) model = model.to("cuda:0" if torch.cuda.is_available() else "cpu") +model = ControlModel(model, list(range(-5, -18, -1))) +user_tag, asst_tag = "[INST]", "[/INST]" @@ -44,7 +46,37 @@ def index(): form = MyForm() result = None if form.validate_on_submit(): + vector_path = form.text_options.data + + control_vector = torch.load(vector_path) + input_field = form.input_field.data + + input_query = user_tag + input_field + asst_tag + + print("magnitude: %s" % form.magnitude.data) + + mag = float(form.magnitude.data) + + input_ids = tokenizer(input_query, return_tensors="pt").to(model.device) + settings = { + "pad_token_id": tokenizer.eos_token_id, # silence warning + "do_sample": False, # temperature=0 + "max_new_tokens": 512, + "repetition_penalty": 1.1, # reduce control jank + } + + print("==baseline") + model.reset() + default_output = str(tokenizer.decode(model.generate(**input_ids, **settings).squeeze())) + + + print("\n++control") + # add the control vector with a certain strength (try increasing or decreasing this!) + model.set_control(control_vector, mag) + control_output = str(tokenizer.decode(model.generate(**input_ids, **settings).squeeze())) + + model.reset() #X = vec.transform([input_field]) #pred = clf.predict(X)[0] pred = "PRED" @@ -52,6 +84,6 @@ def index(): proba = 0.5 # flash(input_field) return render_template('myform.html', title='', form=form, - prediction=labels[pred], confidence='%.2f' % proba) + prediction=control_output, confidence='%.2f' % proba) #return redirect('/index') return render_template('myform.html', title='', form=form, prediction=None, confidence=None)