From 4ee9de8c3c925420b50e9259c306f4f309c67e8e Mon Sep 17 00:00:00 2001
From: Patrick Haller
Date: Thu, 2 Nov 2023 11:38:46 +0000
Subject: [PATCH] Working draft

---
 src/fabricator/dataset_generator.py | 52 +++++++++++++++++------------
 src/fabricator/nodes/__init__.py    |  7 ++++
 src/fabricator/nodes/base.py        | 24 +++++++++++++
 src/fabricator/nodes/nodes.py       | 49 +++++++++++++++++++++++++++
 4 files changed, 110 insertions(+), 22 deletions(-)
 create mode 100644 src/fabricator/nodes/__init__.py
 create mode 100644 src/fabricator/nodes/base.py
 create mode 100644 src/fabricator/nodes/nodes.py

diff --git a/src/fabricator/dataset_generator.py b/src/fabricator/dataset_generator.py
index 0a46420..613d654 100644
--- a/src/fabricator/dataset_generator.py
+++ b/src/fabricator/dataset_generator.py
@@ -61,7 +61,7 @@ def generate(
         num_samples_to_generate: int = 10,
         timeout_per_prompt: Optional[int] = None,
         log_every_n_api_calls: int = 25,
-        dummy_response: Optional[Union[str, Callable]] = None
+        dummy_response: Optional[Union[str, Callable]] = None,
     ) -> Union[Dataset, Tuple[Dataset, Dataset]]:
         """Generate a dataset based on a prompt template and support examples.
         Optionally, unlabeled examples can be provided to annotate unlabeled data.
@@ -93,8 +93,11 @@ def generate(
         if fewshot_dataset:
             self._assert_fewshot_dataset_matches_prompt(prompt_template, fewshot_dataset)
 
-        assert fewshot_sampling_strategy in [None, "uniform", "stratified"], \
-            "Sampling strategy must be 'uniform' or 'stratified'"
+        assert fewshot_sampling_strategy in [
+            None,
+            "uniform",
+            "stratified",
+        ], "Sampling strategy must be 'uniform' or 'stratified'"
 
         if fewshot_dataset and not fewshot_sampling_column:
             fewshot_sampling_column = prompt_template.generate_data_for_column[0]
@@ -111,7 +114,7 @@ def generate(
             num_samples_to_generate,
             timeout_per_prompt,
             log_every_n_api_calls,
-            dummy_response
+            dummy_response,
         )
 
         if return_unlabeled_dataset:
@@ -134,7 +137,6 @@ def _try_generate(
         """
 
         if dummy_response:
-
             if isinstance(dummy_response, str):
                 logger.info(f"Returning dummy response: {dummy_response}")
                 return dummy_response
@@ -152,7 +154,7 @@ def _try_generate(
             prediction = self.prompt_node.run(
                 prompt_template=HaystackPromptTemplate(prompt=prompt_text),
                 invocation_context=invocation_context,
-            )[0]["results"]
+            )
         except Exception as error:
             logger.error(f"Error while generating example: {error}")
             return None
@@ -172,7 +174,7 @@ def _inner_generate_loop(
         num_samples_to_generate: int,
         timeout_per_prompt: Optional[int],
         log_every_n_api_calls: int = 25,
-        dummy_response: Optional[Union[str, Callable]] = None
+        dummy_response: Optional[Union[str, Callable]] = None,
     ):
         current_tries_left = self._max_tries
         current_log_file = self._setup_log(prompt_template)
@@ -200,8 +202,11 @@ def _inner_generate_loop(
 
             if fewshot_dataset:
                 prompt_labels, fewshot_examples = self._sample_fewshot_examples(
-                    prompt_template, fewshot_dataset, fewshot_sampling_strategy, fewshot_examples_per_class,
-                    fewshot_sampling_column
+                    prompt_template,
+                    fewshot_dataset,
+                    fewshot_sampling_strategy,
+                    fewshot_examples_per_class,
+                    fewshot_sampling_column,
                 )
 
             prompt_text = prompt_template.get_prompt_text(prompt_labels, fewshot_examples)
@@ -231,6 +236,7 @@ def _inner_generate_loop(
                         f" {len(generated_dataset)} examples."
                     )
                     break
+                continue
 
             if len(prediction) == 1:
                 prediction = prediction[0]
@@ -310,8 +316,9 @@ def _convert_prediction(self, prediction: str, target_type: type) -> Any:
             return target_type(prediction)
         except ValueError:
             logger.warning(
-                "Could not convert prediction {} to type {}. "
-                "Returning original prediction.", repr(prediction), target_type
+                "Could not convert prediction {} to type {}. " "Returning original prediction.",
+                repr(prediction),
+                target_type,
             )
 
         return prediction
@@ -321,21 +328,20 @@ def _sample_fewshot_examples(
         prompt_template: BasePrompt,
         fewshot_dataset: Dataset,
         fewshot_sampling_strategy: str,
         fewshot_examples_per_class: int,
-        fewshot_sampling_column: str
+        fewshot_sampling_column: str,
     ) -> Tuple[Union[List[str], str], Dataset]:
-
         if fewshot_sampling_strategy == "uniform":
             prompt_labels = choice(prompt_template.label_options, 1)[0]
-            fewshot_examples = fewshot_dataset.filter(
-                lambda example: example[fewshot_sampling_column] == prompt_labels
-            ).shuffle().select(range(fewshot_examples_per_class))
+            fewshot_examples = (
+                fewshot_dataset.filter(lambda example: example[fewshot_sampling_column] == prompt_labels)
+                .shuffle()
+                .select(range(fewshot_examples_per_class))
+            )
         elif fewshot_sampling_strategy == "stratified":
             prompt_labels = prompt_template.label_options
             fewshot_examples = single_label_stratified_sample(
-                fewshot_dataset,
-                fewshot_sampling_column,
-                fewshot_examples_per_class
+                fewshot_dataset, fewshot_sampling_column, fewshot_examples_per_class
             )
 
         else:
@@ -345,9 +351,11 @@ def _sample_fewshot_examples(
             else:
                 fewshot_examples = fewshot_dataset.shuffle()
 
-        assert len(fewshot_examples) > 0, f"Could not find any fewshot examples for label(s) {prompt_labels}." \
-            f"Ensure that labels of fewshot examples match the label_options " \
-            f"from the prompt."
+        assert len(fewshot_examples) > 0, (
+            f"Could not find any fewshot examples for label(s) {prompt_labels}."
+            f"Ensure that labels of fewshot examples match the label_options "
+            f"from the prompt."
+        )
 
         return prompt_labels, fewshot_examples
 
diff --git a/src/fabricator/nodes/__init__.py b/src/fabricator/nodes/__init__.py
new file mode 100644
index 0000000..e5d350d
--- /dev/null
+++ b/src/fabricator/nodes/__init__.py
@@ -0,0 +1,7 @@
+from .base import PromptNode
+from .nodes import GuidedPromptNode
+
+__all__ = [
+    "PromptNode",
+    "GuidedPromptNode",
+]
diff --git a/src/fabricator/nodes/base.py b/src/fabricator/nodes/base.py
new file mode 100644
index 0000000..1ee16e5
--- /dev/null
+++ b/src/fabricator/nodes/base.py
@@ -0,0 +1,24 @@
+from abc import ABC, abstractmethod
+
+from typing import Any, Optional, Dict, Union
+
+from haystack.nodes import PromptNode as HaystackPromptNode
+from haystack.nodes.prompt import PromptTemplate as HaystackPromptTemplate
+
+
+class Node(ABC):
+    @abstractmethod
+    def run(self, prompt_template):
+        pass
+
+
+class PromptNode(Node):
+    def __init__(self, model_name_or_path: str, *args, **kwargs) -> None:
+        self._prompt_node = HaystackPromptNode(model_name_or_path, *args, **kwargs)
+
+    def run(
+        self,
+        prompt_template: Optional[Union[str, HaystackPromptTemplate]],
+        invocation_context: Optional[Dict[str, Any]] = None,
+    ):
+        return self._prompt_node.run(prompt_template=prompt_template, invocation_context=invocation_context)[0]["results"]
diff --git a/src/fabricator/nodes/nodes.py b/src/fabricator/nodes/nodes.py
new file mode 100644
index 0000000..9ecc568
--- /dev/null
+++ b/src/fabricator/nodes/nodes.py
@@ -0,0 +1,49 @@
+import json
+
+from typing import Optional, Dict, Union
+
+from haystack.nodes import PromptNode as HaystackPromptNode
+from haystack.nodes.prompt import PromptTemplate as HaystackPromptTemplate
+
+from .base import Node
+
+try:
+    import outlines.models as models
+    import outlines.text.generate as generate
+
+    from pydantic import BaseModel
+
+    import torch
+except ImportError as exc:
+    raise ImportError("Try 'pip install outlines'") from exc
+
+
+class GuidedPromptNode(Node):
+    def __init__(
+        self,
+        model_name_or_path: str,
+        schema: Union[str, BaseModel],
+        max_length: int = 100,
+        device: Optional[str] = None,
+        model_kwargs: Dict = None,
+        manual_seed: Optional[int] = None,
+    ) -> None:
+        self.max_length = max_length
+        model_kwargs = model_kwargs or {}
+        self._model = models.transformers(model_name_or_path, device=device, **model_kwargs)
+        # JSON schema of class
+        if not isinstance(schema, str):
+            schema = json.dumps(schema.schema())
+
+        self._generator = generate.json(
+            self._model,
+            schema,
+            self.max_length,
+        )
+
+        self.rng = torch.Generator(device=device)
+        if manual_seed is not None:
+            self.rng.manual_seed(manual_seed)
+
+    def run(self, prompt_template: HaystackPromptTemplate, **kwargs):
+        return self._generator(prompt_template.prompt_text, rng=self.rng)
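
Usage note for reviewers (not part of the patch): a minimal sketch of how the two wrappers added under src/fabricator/nodes/ could be exercised. The model names and the Movie schema are illustrative placeholders, the import path assumes the package is installed from the src layout, and the snippet assumes outlines, pydantic, and torch are available as the ImportError guard in nodes.py suggests.

from pydantic import BaseModel
from haystack.nodes.prompt import PromptTemplate as HaystackPromptTemplate

from fabricator.nodes import GuidedPromptNode, PromptNode


class Movie(BaseModel):
    # Placeholder schema; GuidedPromptNode serializes it via schema.schema().
    title: str
    year: int


template = HaystackPromptTemplate(prompt="Invent a plausible science-fiction movie and describe it.")

# Thin wrapper around haystack's PromptNode; run() returns the results list directly.
plain_node = PromptNode("google/flan-t5-base")
free_text = plain_node.run(prompt_template=template)

# Guided generation: decoding is constrained to JSON matching the Movie schema.
guided_node = GuidedPromptNode("gpt2", schema=Movie, max_length=128, device="cpu", manual_seed=42)
structured = guided_node.run(prompt_template=template)

print(free_text, structured)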