Skip to content

Commit

Permalink
Refactor and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
john681611 committed Oct 23, 2023
1 parent 44feb20 commit 609bda0
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 12 deletions.
15 changes: 5 additions & 10 deletions application/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,16 +494,12 @@ def format_path_record(rec):

@classmethod
def standards(self) -> List[str]:
# TODO (JOHN) REDUCE DUPLICATION & SIMPLIFY
tools = []
results = []
for x in db.cypher_query("""MATCH (n:NeoTool) RETURN DISTINCT n.name""")[0]:
tools.extend(x)
standards = []
for x in db.cypher_query("""MATCH (n:NeoStandard) RETURN DISTINCT n.name""")[
0
]: # 0 is the results, 1 is the "n.name" param
standards.extend(x)
return list(set([x for x in tools] + [x for x in standards]))
results.extend(x)
for x in db.cypher_query("""MATCH (n:NeoStandard) RETURN DISTINCT n.name""")[0]:
results.extend(x)
return list(set(results))

@staticmethod
def parse_node(node: NeoDocument) -> cre_defs.Document:
Expand Down Expand Up @@ -1812,7 +1808,6 @@ def gap_analysis(
store_in_cache
): # lightweight memory option to not return potentially huge object and instead store in a cache,
# in case this is called via worker, we save both this and the caller memory by avoiding duplicate object in mem
# TODO (JOHN) MOCK AND TEST REDIS CALLS
conn = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379"))
if cache_key == "":
cache_key = make_array_hash(node_names)
Expand Down
109 changes: 108 additions & 1 deletion application/tests/web_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import patch

import redis
import rq

from application import create_app, sqla # type: ignore
from application.database import db
Expand All @@ -12,6 +13,12 @@
from application.web import web_main


class MockJob:
@property
def id(self):
return "ABC"


class TestMain(unittest.TestCase):
def tearDown(self) -> None:
sqla.session.remove()
Expand Down Expand Up @@ -566,7 +573,107 @@ def test_smartlink(self) -> None:
self.assertEqual(location, "")
self.assertEqual(404, response.status_code)

# TODO: (JOHN) Basic gap analysis endpoint tests
@patch.object(redis, "from_url")
def test_gap_analysis_from_cache_full_response(self, redis_conn_mock) -> None:
expected = {"result": "hello"}
redis_conn_mock.return_value.exists.return_value = True
redis_conn_mock.return_value.get.return_value = json.dumps(expected)
with self.app.test_client() as client:
response = client.get(
"/rest/v1/map_analysis?standard=aaa&standard=bbb",
headers={"Content-Type": "application/json"},
)
self.assertEqual(200, response.status_code)
self.assertEqual(expected, json.loads(response.data))

@patch.object(rq.Queue, "enqueue_call")
@patch.object(redis, "from_url")
def test_gap_analysis_from_cache_job_id(
self, redis_conn_mock, enqueue_call_mock
) -> None:
expected = {"job_id": "hello"}
redis_conn_mock.return_value.exists.return_value = True
redis_conn_mock.return_value.get.return_value = json.dumps(expected)
with self.app.test_client() as client:
response = client.get(
"/rest/v1/map_analysis?standard=aaa&standard=bbb",
headers={"Content-Type": "application/json"},
)
self.assertEqual(200, response.status_code)
self.assertEqual(expected, json.loads(response.data))
self.assertFalse(enqueue_call_mock.called)

@patch.object(db, "Node_collection")
@patch.object(rq.Queue, "enqueue_call")
@patch.object(redis, "from_url")
def test_gap_analysis_create_job_id(
self, redis_conn_mock, enqueue_call_mock, db_mock
) -> None:
expected = {"job_id": "ABC"}
redis_conn_mock.return_value.exists.return_value = False
enqueue_call_mock.return_value = MockJob()
with self.app.test_client() as client:
response = client.get(
"/rest/v1/map_analysis?standard=aaa&standard=bbb",
headers={"Content-Type": "application/json"},
)
self.assertEqual(200, response.status_code)
self.assertEqual(expected, json.loads(response.data))
enqueue_call_mock.assert_called_with(
db.gap_analysis,
kwargs={
"neo_db": db_mock().neo_db,
"node_names": ["aaa", "bbb"],
"store_in_cache": True,
"cache_key": "7aa45d88f69a131890f8e4a769bbb07b",
},
)
redis_conn_mock.return_value.set.assert_called_with(
"7aa45d88f69a131890f8e4a769bbb07b", '{"job_id": "ABC", "result": ""}'
)

@patch.object(redis, "from_url")
def test_standards_from_cache(self, redis_conn_mock) -> None:
expected = ["A", "B"]
redis_conn_mock.return_value.exists.return_value = True
redis_conn_mock.return_value.get.return_value = json.dumps(expected)
with self.app.test_client() as client:
response = client.get(
"/rest/v1/standards",
headers={"Content-Type": "application/json"},
)
self.assertEqual(200, response.status_code)
self.assertEqual(expected, json.loads(response.data))

@patch.object(redis, "from_url")
@patch.object(db, "Node_collection")
def test_standards_from_db(self, node_mock, redis_conn_mock) -> None:
expected = ["A", "B"]
redis_conn_mock.return_value.get.return_value = None
node_mock.return_value.standards.return_value = expected
with self.app.test_client() as client:
response = client.get(
"/rest/v1/standards",
headers={"Content-Type": "application/json"},
)
self.assertEqual(200, response.status_code)
self.assertEqual(expected, json.loads(response.data))

@patch.object(redis, "from_url")
@patch.object(db, "Node_collection")
def test_standards_from_db_off(self, node_mock, redis_conn_mock) -> None:
expected = {
"message": "Backend services connected to this feature are not running at the moment."
}
redis_conn_mock.return_value.get.return_value = None
node_mock.return_value.standards.return_value = None
with self.app.test_client() as client:
response = client.get(
"/rest/v1/standards",
headers={"Content-Type": "application/json"},
)
self.assertEqual(500, response.status_code)
self.assertEqual(expected, json.loads(response.data))

def test_gap_analysis_weak_links_no_cache(self) -> None:
with self.app.test_client() as client:
Expand Down
2 changes: 1 addition & 1 deletion application/web/web_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def standards() -> Any:
database = db.Node_collection()
standards = database.standards()
if standards is None:
neo4j_not_running_rejection()
return neo4j_not_running_rejection()
conn.set("NodeNames", flask_json.dumps(standards))
return standards

Expand Down

0 comments on commit 609bda0

Please sign in to comment.