Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changes in client api #273

Merged
merged 9 commits into from
Jan 22, 2021
139 changes: 74 additions & 65 deletions metagraph/api/python/metagraph/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

"""Metagraph client."""

DEFAULT_TOP_LABELS = 10000
DEFAULT_DISCOVERY_THRESHOLD = 0.7
DEFAULT_TOP_LABELS = 100
DEFAULT_DISCOVERY_THRESHOLD = 0
DEFAULT_NUM_NODES_PER_SEQ_CHAR = 10.0

JsonDict = Dict[str, Any]
Expand All @@ -25,29 +25,41 @@ class GraphClientJson:
returning error message in the second element of the tuple returned.
"""

def __init__(self, host: str, port: int, label: str = None, api_path: str = None):
def __init__(self, host: str, port: int, name: str = None, api_path: str = None):
self.host = host
self.port = port
self.label = label
self.api_path = api_path

self.server = f"http://{self.host}:{self.port}"
if api_path:
self.server = f"{self.server}/{api_path.lstrip('/')}"

self.name = name if name else self.server

def search(self, sequence: Union[str, Iterable[str]],
top_labels: int = DEFAULT_TOP_LABELS,
discovery_threshold: float = DEFAULT_DISCOVERY_THRESHOLD,
align: bool = False,
max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> Tuple[JsonDict, str]:
**align_params) -> Tuple[JsonDict, str]:
"""See parameters for alignment `align_params` in align()"""

if discovery_threshold < 0.0 or discovery_threshold > 1.0:
raise ValueError(
f"discovery_threshold should be between 0 and 1 inclusive. Got {discovery_threshold}")

if max_num_nodes_per_seq_char < 0:
warnings.warn("max_num_nodes_per_seq_char < 0, treating as infinite", RuntimeWarning)
if align:
json_obj = self.align(sequence, discovery_threshold, **align_params)

def to_fasta(df):
fasta = []
for i in range(df.shape[0]):
fasta.append(f">{df.loc[i, 'seq_description']}\n{df.loc[i, 'sequence']}")
return '\n'.join(fasta)

sequence = to_fasta(helpers.df_from_align_result(json_obj))

param_dict = {"count_labels": True,
"discovery_fraction": discovery_threshold,
"num_labels": top_labels,
"align": align,
"max_num_nodes_per_seq_char": max_num_nodes_per_seq_char}
"num_labels": top_labels}

return self._json_seq_query(sequence, param_dict, "search")

Expand All @@ -65,19 +77,16 @@ def align(self, sequence: Union[str, Iterable[str]],
params = {'max_alternative_alignments': max_alternative_alignments,
'max_num_nodes_per_seq_char': max_num_nodes_per_seq_char,
'discovery_fraction': discovery_threshold}
return self._json_seq_query(sequence, params, "align")

# noinspection PyTypeChecker
def column_labels(self) -> Tuple[JsonStrList, str]:
return self._do_request("column_labels", {}, False)
return self._json_seq_query(sequence, params, "align")

def _json_seq_query(self, sequence: Union[str, Iterable[str]], param_dict,
endpoint: str) -> Tuple[JsonDict, str]:
if isinstance(sequence, str):
fasta_str = f">query\n{sequence}"
else:
fasta_str = '\n'.join(
[f">{i}\n{seq}" for (i, seq) in enumerate(sequence)])
sequence = [sequence]

fasta_str = '\n'.join(
[f">{i}\n{seq}" for (i, seq) in enumerate(sequence)])

payload_dict = {"FASTA": fasta_str}
payload_dict.update(param_dict)
Expand All @@ -86,12 +95,7 @@ def _json_seq_query(self, sequence: Union[str, Iterable[str]], param_dict,
return self._do_request(endpoint, payload)

def _do_request(self, endpoint, payload, post_req=True) -> Tuple[JsonDict, str]:
endpoint_path = endpoint

if self.api_path:
endpoint_path = f"{self.api_path.lstrip('/')}/{endpoint}"

url = f'http://{self.host}:{self.port}/{endpoint_path}'
url = f'{self.server}/{endpoint}'
if post_req:
ret = requests.post(url=url, json=payload)
else:
Expand All @@ -100,70 +104,76 @@ def _do_request(self, endpoint, payload, post_req=True) -> Tuple[JsonDict, str]:
try:
json_obj = ret.json()
except:
return {}, str(ret.status_code) + " " + str(ret)
raise RuntimeError(
f"Error while calling the server API. {str(ret.status_code)}: {ret.text}")

if not ret.ok:
error_msg = json_obj[
'error'] if 'error' in json_obj.keys() else str(json_obj)
return {}, str(ret.status_code) + " " + error_msg
error_msg = json_obj['error'] if 'error' in json_obj.keys() else str(json_obj)
raise RuntimeError(
f"Error while calling the server API. {str(ret.status_code)}: {error_msg}")

return json_obj, ""
return json_obj

# noinspection PyTypeChecker
def column_labels(self) -> Tuple[JsonStrList, str]:
return self._do_request("column_labels", {}, post_req=False)

def stats(self) -> Tuple[dict, str]:
return self._do_request("stats", {}, post_req=False)

def ready(self) -> bool:
try:
self.stats()
return True
except RuntimeError as e:
if "503: Server is currently initializing" in str(e):
return False
raise e


class GraphClient:
def __init__(self, host: str, port: int, label: str = None, api_path: str = None):
self._json_client = GraphClientJson(host, port, api_path=api_path)
self.label = label
def __init__(self, host: str, port: int, name: str = None, api_path: str = None):
self._json_client = GraphClientJson(host, port, name, api_path=api_path)
self.name = self._json_client.name

def search(self, sequence: Union[str, Iterable[str]],
top_labels: int = DEFAULT_TOP_LABELS,
discovery_threshold: float = DEFAULT_DISCOVERY_THRESHOLD,
align: bool = False,
max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> pd.DataFrame:
(json_obj, err) = self._json_client.search(sequence, top_labels,
discovery_threshold, align,
max_num_nodes_per_seq_char)
**align_params) -> pd.DataFrame:
"""See parameters for alignment `align_params` in align()"""

if err:
raise RuntimeError(
f"Error while calling the server API {str(err)}")
json_obj = self._json_client.search(sequence, top_labels,
discovery_threshold,
align, **align_params)

return helpers.df_from_search_result(json_obj)

def align(self, sequence: Union[str, Iterable[str]],
discovery_threshold: float = DEFAULT_DISCOVERY_THRESHOLD,
max_alternative_alignments: int = 1,
max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> pd.DataFrame:
json_obj, err = self._json_client.align(sequence, discovery_threshold,
max_alternative_alignments, max_num_nodes_per_seq_char)

if err:
raise RuntimeError(f"Error while calling the server API {str(err)}")
json_obj = self._json_client.align(sequence, discovery_threshold,
max_alternative_alignments,
max_num_nodes_per_seq_char)

return helpers.df_from_align_result(json_obj)


def column_labels(self) -> List[str]:
json_obj, err = self._json_client.column_labels()
return self._json_client.column_labels()

if err:
raise RuntimeError(f"Error while calling the server API {str(err)}")
return json_obj
def ready(self) -> bool:
return self._json_client.ready()


class MultiGraphClient:
# TODO: make things asynchronously. this should be the added value of this class
def __init__(self):
self.graphs = {}

def add_graph(self, host: str, port: int, label: str = None, api_path: str = None) -> None:
if not label:
label = f"{host}:{port}"

self.graphs[label] = GraphClient(host, port, label, api_path=api_path)
def add_graph(self, host: str, port: int, name: str = None, api_path: str = None) -> None:
graph_client = GraphClient(host, port, name, api_path=api_path)
self.graphs[graph_client.name] = graph_client

def list_graphs(self) -> Dict[str, Tuple[str, int]]:
return {lbl: (inst.host, inst.port) for (lbl, inst) in
Expand All @@ -173,15 +183,14 @@ def search(self, sequence: Union[str, Iterable[str]],
top_labels: int = DEFAULT_TOP_LABELS,
discovery_threshold: float = DEFAULT_DISCOVERY_THRESHOLD,
align: bool = False,
max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> \
Dict[str, pd.DataFrame]:
**align_params) -> Dict[str, pd.DataFrame]:
"""See parameters for alignment `align_params` in align()"""

result = {}
for label, graph_client in self.graphs.items():
result[label] = graph_client.search(sequence, top_labels,
for name, graph_client in self.graphs.items():
result[name] = graph_client.search(sequence, top_labels,
discovery_threshold,
align,
max_num_nodes_per_seq_char)
align, **align_params)

return result

Expand All @@ -191,17 +200,17 @@ def align(self, sequence: Union[str, Iterable[str]],
max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> Dict[
str, pd.DataFrame]:
result = {}
for label, graph_client in self.graphs.items():
for name, graph_client in self.graphs.items():
# TODO: do this async
result[label] = graph_client.align(sequence, discovery_threshold,
result[name] = graph_client.align(sequence, discovery_threshold,
max_alternative_alignments,
max_num_nodes_per_seq_char)

return result

def column_labels(self) -> Dict[str, List[str]]:
ret = {}
for label, graph_client in self.graphs.items():
ret[label] = graph_client.column_labels()
for name, graph_client in self.graphs.items():
ret[name] = graph_client.column_labels()

return ret
27 changes: 8 additions & 19 deletions metagraph/api/python/metagraph/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,32 @@
def df_from_search_result(json_res):
def _build_dict(row, result):
d = dict(row)
# `properties` (optional): dictionary with metadata about the sample
if 'properties' in d.keys():
props = d.pop('properties')
else:
props = {}

props['seq_description'] = result['seq_description']

if props['seq_description'] == 'query':
props['seq_description'] = 0 # for consistency

if 'cigar' in result.keys():
# we did alignment
props['sequence'] = result['sequence']
props['score'] = result['score']
props['cigar'] = result['cigar']

return {**d, **props}

lst = [_build_dict(row, result) for result in json_res for row in
result['results']]

if lst:
return pd.DataFrame(lst)
else:
# columns may vary on the graph, so not adding column information
return pd.DataFrame()
# columns of the table may vary on the graph, and inferred automatically
return pd.DataFrame(lst)


def df_from_align_result(json_res):
# flatten out json result
lst = [alignment for result in json_res for alignment in
result['alignments']]
lst = [(alignment['cigar'],
alignment['score'],
alignment['sequence'],
result['seq_description'])
for result in json_res for alignment in result['alignments']]

df = pd.DataFrame(lst,
columns=['cigar', 'score', 'sequence', 'seq_description'])

# for consistency, set seq_description to 0 even if we only queried a single sequence
df.loc[df['seq_description'] == 'query', 'seq_description'] = 0

return df
15 changes: 9 additions & 6 deletions metagraph/integration_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,7 @@ def test_api_simple_query_align_df(self):
ret = self.graph_client.search(self.sample_query, discovery_threshold=0.01, align=True)
df = ret[self.graph_name]

self.assertIn('cigar', df.columns)
self.assertEqual((self.sample_query_expected_rows, 6), df.shape)
self.assertEqual((self.sample_query_expected_rows, 3), df.shape)

def test_api_client_column_labels(self):
ret = self.graph_client.column_labels()
Expand Down Expand Up @@ -274,13 +273,17 @@ def setUpClass(cls):

cls.graph_client = GraphClientJson(cls.host, cls.port)

def setUp(self):
if not self.graph_client.ready():
self.fail("Server takes too long to initialize")

def test_api_align_json(self):
ret, _ = self.graph_client.align("TCGATCGA")
ret = self.graph_client.align("TCGATCGA")
self.assertEqual(len(ret), 1)

# do various queries
def test_api_simple_query(self):
res_list, _ = self.graph_client.search(self.sample_query, discovery_threshold=0.01)
res_list = self.graph_client.search(self.sample_query, discovery_threshold=0.01)

self.assertEqual(len(res_list), 1)

Expand All @@ -297,14 +300,14 @@ def test_api_simple_query(self):
def test_api_multiple_queries(self):
repetitions = 4

res_list, _ = self.graph_client.search([self.sample_query] * repetitions)
res_list = self.graph_client.search([self.sample_query] * repetitions)
self.assertEqual(len(res_list), repetitions)

# testing if the returned query indices range from 0 to n - 1
self.assertEqual(sorted(range(0, repetitions)), sorted([int(a['seq_description']) for a in res_list]))

def test_api_stats(self):
res = self.graph_client.stats()[0]
res = self.graph_client.stats()

self.assertIn("graph", res.keys())
graph_props = res['graph']
Expand Down
1 change: 1 addition & 0 deletions metagraph/src/cli/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ std::thread start_server(HttpServer &server_startup, Config &config) {
template<typename T>
bool check_data_ready(std::shared_future<T> &data, shared_ptr<HttpServer::Response> response) {
if (data.wait_for(0s) != std::future_status::ready) {
logger->info("[Server] Got a request during initialization. Asked to come back later");
response->write(SimpleWeb::StatusCode::server_error_service_unavailable,
"Server is currently initializing, please come back later.");
return false;
Expand Down