feat: add dataset functions (#59)
Signed-off-by: Grant Linville <[email protected]>
g-linville authored Oct 14, 2024
1 parent 97d819c commit cd76120
Showing 6 changed files with 151 additions and 2 deletions.
25 changes: 25 additions & 0 deletions gptscript/datasets.py
@@ -0,0 +1,25 @@
from typing import Dict
from pydantic import BaseModel


class DatasetElementMeta(BaseModel):
name: str
description: str


class DatasetElement(BaseModel):
name: str
description: str
contents: str


class DatasetMeta(BaseModel):
id: str
name: str
description: str


class Dataset(BaseModel):
id: str
name: str
description: str
elements: Dict[str, DatasetElementMeta]
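
These are plain pydantic v2 models, so each payload the dataset tool returns can be parsed directly with model_validate_json / model_validate. A minimal sketch, not part of this commit (the payload values below are hypothetical):

import json
from gptscript.datasets import Dataset, DatasetElementMeta

# Hypothetical payload in the shape the models above expect.
raw = json.dumps({
    "id": "ds-123",
    "name": "my-dataset",
    "description": "example dataset",
    "elements": {
        "element1": {"name": "element1", "description": "first element"},
    },
})

dataset = Dataset.model_validate_json(raw)
assert isinstance(dataset.elements["element1"], DatasetElementMeta)
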
81 changes: 81 additions & 0 deletions gptscript/gptscript.py
@@ -7,6 +7,7 @@

from gptscript.confirm import AuthResponse
from gptscript.credentials import Credential, to_credential
from gptscript.datasets import DatasetMeta, Dataset, DatasetElementMeta, DatasetElement
from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
from gptscript.opts import GlobalOptions
from gptscript.prompt import PromptResponse
@@ -210,6 +211,86 @@ async def delete_credential(self, context: str = "default", name: str = "") -> str:
{"context": [context], "name": name}
)

async def list_datasets(self, workspace: str) -> List[DatasetMeta]:
if workspace == "":
workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

res = await self._run_basic_command(
"datasets",
{"input": "{}", "workspace": workspace, "datasetToolRepo": self.opts.DatasetToolRepo}
)
return [DatasetMeta.model_validate(d) for d in json.loads(res)]

async def create_dataset(self, workspace: str, name: str, description: str = "") -> Dataset:
if workspace == "":
workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

if name == "":
raise ValueError("name cannot be empty")

res = await self._run_basic_command(
"datasets/create",
{"input": json.dumps({"datasetName": name, "datasetDescription": description}),
"workspace": workspace,
"datasetToolRepo": self.opts.DatasetToolRepo}
)
return Dataset.model_validate_json(res)

async def add_dataset_element(self, workspace: str, datasetID: str, elementName: str, elementContent: str,
elementDescription: str = "") -> DatasetElementMeta:
if workspace == "":
workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

if datasetID == "":
raise ValueError("datasetID cannot be empty")
elif elementName == "":
raise ValueError("elementName cannot be empty")
elif elementContent == "":
raise ValueError("elementContent cannot be empty")

res = await self._run_basic_command(
"datasets/add-element",
{"input": json.dumps({"datasetID": datasetID,
"elementName": elementName,
"elementContent": elementContent,
"elementDescription": elementDescription}),
"workspace": workspace,
"datasetToolRepo": self.opts.DatasetToolRepo}
)
return DatasetElementMeta.model_validate_json(res)

async def list_dataset_elements(self, workspace: str, datasetID: str) -> List[DatasetElementMeta]:
if workspace == "":
workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

if datasetID == "":
raise ValueError("datasetID cannot be empty")

res = await self._run_basic_command(
"datasets/list-elements",
{"input": json.dumps({"datasetID": datasetID}),
"workspace": workspace,
"datasetToolRepo": self.opts.DatasetToolRepo}
)
return [DatasetElementMeta.model_validate(d) for d in json.loads(res)]

async def get_dataset_element(self, workspace: str, datasetID: str, elementName: str) -> DatasetElement:
if workspace == "":
workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

if datasetID == "":
raise ValueError("datasetID cannot be empty")
elif elementName == "":
raise ValueError("elementName cannot be empty")

res = await self._run_basic_command(
"datasets/get-element",
{"input": json.dumps({"datasetID": datasetID, "element": elementName}),
"workspace": workspace,
"datasetToolRepo": self.opts.DatasetToolRepo}
)
return DatasetElement.model_validate_json(res)


def _get_command():
if os.getenv("GPTSCRIPT_BIN") is not None:
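
Taken together, these methods give the client a small async create/add/get/list surface over the dataset tool. A minimal usage sketch, not part of this commit (it assumes the SDK's GPTScript client; the workspace path and element values are hypothetical):

import asyncio
from gptscript.gptscript import GPTScript

async def main():
    g = GPTScript()
    try:
        # Create a dataset in the workspace, add one element, then read it back.
        dataset = await g.create_dataset("/tmp/my-workspace", "my-dataset", "a demo dataset")
        await g.add_dataset_element("/tmp/my-workspace", dataset.id, "element1", "element1 contents")
        element = await g.get_dataset_element("/tmp/my-workspace", dataset.id, "element1")
        print(element.contents)
    finally:
        g.close()

asyncio.run(main())
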
6 changes: 5 additions & 1 deletion gptscript/opts.py
@@ -12,6 +12,7 @@ def __init__(
defaultModelProvider: str = "",
defaultModel: str = "",
cacheDir: str = "",
datasetToolRepo: str = "",
env: list[str] = None,
):
self.URL = url
@@ -21,6 +22,7 @@
self.DefaultModel = defaultModel
self.DefaultModelProvider = defaultModelProvider
self.CacheDir = cacheDir
self.DatasetToolRepo = datasetToolRepo
if env is None:
env = [f"{k}={v}" for k, v in os.environ.items()]
elif isinstance(env, dict):
@@ -38,6 +40,7 @@ def merge(self, other: Self) -> Self:
cp.DefaultModel = other.DefaultModel if other.DefaultModel != "" else self.DefaultModel
cp.DefaultModelProvider = other.DefaultModelProvider if other.DefaultModelProvider != "" else self.DefaultModelProvider
cp.CacheDir = other.CacheDir if other.CacheDir != "" else self.CacheDir
cp.DatasetToolRepo = other.DatasetToolRepo if other.DatasetToolRepo != "" else self.DatasetToolRepo
cp.Env = (other.Env or [])
cp.Env.extend(self.Env or [])
return cp
@@ -77,8 +80,9 @@ def __init__(self,
defaultModelProvider: str = "",
defaultModel: str = "",
cacheDir: str = "",
datasetToolRepo: str = "",
):
super().__init__(url, token, apiKey, baseURL, defaultModelProvider, defaultModel, cacheDir, env)
super().__init__(url, token, apiKey, baseURL, defaultModelProvider, defaultModel, cacheDir, datasetToolRepo, env)
self.input = input
self.disableCache = disableCache
self.subTool = subTool
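
Note the merge rule: the other options' DatasetToolRepo wins whenever it is non-empty; otherwise the receiver's value is kept. A short sketch of that behavior, not part of this commit (the tool repo URLs are hypothetical):

from gptscript.opts import GlobalOptions

base = GlobalOptions(datasetToolRepo="github.com/example/base-tool")
override = GlobalOptions(datasetToolRepo="github.com/example/override-tool")

print(base.merge(override).DatasetToolRepo)         # github.com/example/override-tool
print(base.merge(GlobalOptions()).DatasetToolRepo)  # github.com/example/base-tool
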
3 changes: 2 additions & 1 deletion requirements.txt
@@ -10,4 +10,5 @@ setuptools==69.1.1
twine==5.0.0
build==1.1.1
httpx==0.27.0
pywin32==306; sys_platform == 'win32'
pydantic==2.9.2
pywin32==306; sys_platform == 'win32'
37 changes: 37 additions & 0 deletions tests/test_gptscript.py
@@ -4,6 +4,7 @@
import os
import platform
import subprocess
import tempfile
from datetime import datetime, timedelta, timezone
from time import sleep

@@ -755,3 +756,39 @@ async def test_credentials(gptscript):

res = await gptscript.delete_credential(name=name)
assert not res.startswith("an error occurred"), "Unexpected error deleting credential: " + res

@pytest.mark.asyncio
async def test_datasets(gptscript):
with tempfile.TemporaryDirectory(prefix="py-gptscript_") as tempdir:
dataset_name = str(os.urandom(8).hex())

# Create dataset
dataset = await gptscript.create_dataset(tempdir, dataset_name, "this is a test dataset")
assert dataset.id != "", "Expected dataset id to be set"
assert dataset.name == dataset_name, "Expected dataset name to match"
assert dataset.description == "this is a test dataset", "Expected dataset description to match"
assert len(dataset.elements) == 0, "Expected dataset elements to be empty"

# Add an element
element_meta = await gptscript.add_dataset_element(tempdir, dataset.id, "element1", "element1 contents", "element1 description")
assert element_meta.name == "element1", "Expected element name to match"
assert element_meta.description == "element1 description", "Expected element description to match"

# Get the element
element = await gptscript.get_dataset_element(tempdir, dataset.id, "element1")
assert element.name == "element1", "Expected element name to match"
assert element.contents == "element1 contents", "Expected element contents to match"
assert element.description == "element1 description", "Expected element description to match"

# List elements in the dataset
elements = await gptscript.list_dataset_elements(tempdir, dataset.id)
assert len(elements) == 1, "Expected one element in the dataset"
assert elements[0].name == "element1", "Expected element name to match"
assert elements[0].description == "element1 description", "Expected element description to match"

# List datasets
datasets = await gptscript.list_datasets(tempdir)
assert len(datasets) > 0, "Expected at least one dataset"
assert datasets[0].id == dataset.id, "Expected dataset id to match"
assert datasets[0].name == dataset_name, "Expected dataset name to match"
assert datasets[0].description == "this is a test dataset", "Expected dataset description to match"
1 change: 1 addition & 0 deletions tox.ini
@@ -6,6 +6,7 @@ deps =
httpx
pytest
pytest-asyncio
pydantic

passenv =
OPENAI_API_KEY
