From 79021a3761b3fed6da7d59fc7fb738dcabc73815 Mon Sep 17 00:00:00 2001 From: Davide Fusca Date: Sun, 10 Nov 2024 11:54:07 +0100 Subject: [PATCH 1/3] Add points put endpoint --- core/cat/routes/memory/points.py | 94 +++++++++++++++++++ .../tests/routes/memory/test_memory_points.py | 82 ++++++++++++++++ 2 files changed, 176 insertions(+) diff --git a/core/cat/routes/memory/points.py b/core/cat/routes/memory/points.py index 5f20724b4..ad8b200ac 100644 --- a/core/cat/routes/memory/points.py +++ b/core/cat/routes/memory/points.py @@ -267,3 +267,97 @@ async def get_points_in_collection( "points": points, "next_offset": next_offset } + + + +# EDIT a point in memory +@router.put("/collections/{collection_id}/points/{point_id}", response_model=MemoryPoint) +async def edit_memory_point( + request: Request, + collection_id: str, + point_id: str, + point: MemoryPointBase, + stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.WRITE)), +) -> MemoryPoint: + """Edit a point in memory + + + Example + ---------- + ``` + + collection = "declarative" + content = "MIAO!" + metadata = {"custom_key": "custom_value"} + req_json = { + "content": content, + "metadata": metadata, + } + # create a point + res = requests.post( + f"http://localhost:1865/memory/collections/{collection}/points", json=req_json + ) + json = res.json() + #get the id + point_id = json["id"] + # new point values + content = "NEW MIAO!" + metadata = {"custom_key": "new_custom_value"} + req_json = { + "content": content, + "metadata": metadata, + } + + # edit the point + res = requests.put( + f"http://localhost:1865/memory/collections/{collection}/points/{point_id}", json=req_json + ) + json = res.json() + print(json) + ``` + """ + + # do not touch procedural memory + if collection_id == "procedural": + raise HTTPException( + status_code=400, detail={"error": "Procedural memory is read-only."} + ) + + vector_memory: VectorMemory = stray.memory.vectors + collections = list(vector_memory.collections.keys()) + if collection_id not in collections: + raise HTTPException( + status_code=400, detail={"error": "Collection does not exist."} + ) + + #ensure point exist + points = vector_memory.collections[collection_id].get_points([point_id]) + if points is None or len(points) == 0: + raise HTTPException( + status_code=400, detail={"error": f"Point with id {point_id} does not exist."} + ) + + # embed content + embedding = stray.embedder.embed_query(point.content) + + # ensure source is set + if not point.metadata.get("source"): + point.metadata["source"] = ( + stray.user_id + ) # this will do also for declarative memory + + # ensure when is set + if not point.metadata.get("when"): + point.metadata["when"] = time.time() #if when is not in the metadata set the current time + + # edit point + qdrant_point = vector_memory.collections[collection_id].add_point( + content=point.content, vector=embedding, metadata=point.metadata, id=point_id + ) + + return MemoryPoint( + metadata=qdrant_point.payload["metadata"], + content=qdrant_point.payload["page_content"], + vector=qdrant_point.vector, + id=qdrant_point.id, + ) \ No newline at end of file diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index 0d8fd827a..483176c0b 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -288,3 +288,85 @@ def test_get_collection_points_offset(client, patch_time_now, collection): assert points_payloads == expected_payloads + +def test_edit_point_wrong_collection_and_not_exist(client): + + req_json = { + "content": "MIAO!" + } + + point_id = 100 + + # wrong collection + res = client.put( + f"/memory/collections/wrongcollection/points/{point_id}", json=req_json + ) + assert res.status_code == 400 + assert "Collection does not exist" in res.json()["detail"]["error"] + + # cannot write procedural point + res = client.put( + "/memory/collections/procedural/points/{point_id}", json=req_json + ) + assert res.status_code == 400 + assert "Procedural memory is read-only" in res.json()["detail"]["error"] + + # point do not exist + res = client.put( + "/memory/collections/declarative/points/{point_id}", json=req_json + ) + assert res.status_code == 400 + assert "Point with id {point_id} does not exist." in res.json()["detail"]["error"] + + + +@pytest.mark.parametrize("collection", ["episodic", "declarative"]) +def test_edit_memory_point(client, patch_time_now, collection): + + # create a point + content = "MIAO!" + metadata = {"custom_key": "custom_value"} + req_json = { + "content": content, + "metadata": metadata, + } + # create a point + res = client.post( + f"/memory/collections/{collection}/points", json=req_json + ) + #get the id + assert res.status_code == 200 + json = res.json() + assert json["id"] + point_id = json["id"] + # new point values + content = "NEW MIAO!" + metadata = {"custom_key": "new_custom_value"} + req_json = { + "content": content, + "metadata": metadata, + } + + res = client.put( + f"/memory/collections/{collection}/points/{point_id}", json=req_json + ) + # check response + assert res.status_code == 200 + json = res.json() + assert json["content"] == content + expected_metadata = {"when":FAKE_TIMESTAMP,"source": "user", **metadata} + assert json["metadata"] == expected_metadata + assert "id" in json + assert "vector" in json + assert isinstance(json["vector"], list) + assert isinstance(json["vector"][0], float) + + # check memory contents + params = {"text": "miao"} + response = client.get("/memory/recall/", params=params) + json = response.json() + assert response.status_code == 200 + assert len(json["vectors"]["collections"][collection]) == 1 + memory = json["vectors"]["collections"][collection][0] + assert memory["page_content"] == content + assert memory["metadata"] == expected_metadata \ No newline at end of file From a3b5fbe16269b04ef7d1904b1fa665bb7e4dd5ad Mon Sep 17 00:00:00 2001 From: Davide Fusca Date: Sun, 10 Nov 2024 12:01:25 +0100 Subject: [PATCH 2/3] Update Auth to EDIT --- core/cat/routes/memory/points.py | 4 ++-- core/tests/routes/memory/test_memory_points.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/cat/routes/memory/points.py b/core/cat/routes/memory/points.py index ad8b200ac..6ba208bcf 100644 --- a/core/cat/routes/memory/points.py +++ b/core/cat/routes/memory/points.py @@ -277,7 +277,7 @@ async def edit_memory_point( collection_id: str, point_id: str, point: MemoryPointBase, - stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.WRITE)), + stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.EDIT)), ) -> MemoryPoint: """Edit a point in memory @@ -334,7 +334,7 @@ async def edit_memory_point( points = vector_memory.collections[collection_id].get_points([point_id]) if points is None or len(points) == 0: raise HTTPException( - status_code=400, detail={"error": f"Point with id {point_id} does not exist."} + status_code=400, detail={"error": f"Point does not exist."} ) # embed content diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index 483176c0b..c204e5615 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -316,7 +316,7 @@ def test_edit_point_wrong_collection_and_not_exist(client): "/memory/collections/declarative/points/{point_id}", json=req_json ) assert res.status_code == 400 - assert "Point with id {point_id} does not exist." in res.json()["detail"]["error"] + assert f"Point does not exist." in res.json()["detail"]["error"] From 36e70e39294b836a17f208d0f1575aeda789fb7e Mon Sep 17 00:00:00 2001 From: Davide Fusca Date: Sun, 10 Nov 2024 12:32:14 +0100 Subject: [PATCH 3/3] Fix linter --- core/cat/routes/memory/points.py | 2 +- core/tests/routes/memory/test_memory_points.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/cat/routes/memory/points.py b/core/cat/routes/memory/points.py index 6ba208bcf..8ad65f563 100644 --- a/core/cat/routes/memory/points.py +++ b/core/cat/routes/memory/points.py @@ -334,7 +334,7 @@ async def edit_memory_point( points = vector_memory.collections[collection_id].get_points([point_id]) if points is None or len(points) == 0: raise HTTPException( - status_code=400, detail={"error": f"Point does not exist."} + status_code=400, detail={"error": "Point does not exist."} ) # embed content diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index c204e5615..7aff691db 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -316,7 +316,7 @@ def test_edit_point_wrong_collection_and_not_exist(client): "/memory/collections/declarative/points/{point_id}", json=req_json ) assert res.status_code == 400 - assert f"Point does not exist." in res.json()["detail"]["error"] + assert "Point does not exist." in res.json()["detail"]["error"]