diff --git a/metagraph/api/python/metagraph/client.py b/metagraph/api/python/metagraph/client.py index cfd6b538db..be9ecd92db 100644 --- a/metagraph/api/python/metagraph/client.py +++ b/metagraph/api/python/metagraph/client.py @@ -10,8 +10,8 @@ """Metagraph client.""" -DEFAULT_TOP_LABELS = 10000 -DEFAULT_DISCOVERY_THRESHOLD = 0.7 +DEFAULT_TOP_LABELS = 100 +DEFAULT_DISCOVERY_THRESHOLD = 0 DEFAULT_NUM_NODES_PER_SEQ_CHAR = 10.0 JsonDict = Dict[str, Any] @@ -25,29 +25,41 @@ class GraphClientJson: returning error message in the second element of the tuple returned. """ - def __init__(self, host: str, port: int, label: str = None, api_path: str = None): + def __init__(self, host: str, port: int, name: str = None, api_path: str = None): self.host = host self.port = port - self.label = label - self.api_path = api_path + + self.server = f"http://{self.host}:{self.port}" + if api_path: + self.server = f"{self.server}/{api_path.lstrip('/')}" + + self.name = name if name else self.server def search(self, sequence: Union[str, Iterable[str]], top_labels: int = DEFAULT_TOP_LABELS, discovery_threshold: float = DEFAULT_DISCOVERY_THRESHOLD, align: bool = False, - max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> Tuple[JsonDict, str]: + **align_params) -> Tuple[JsonDict, str]: + """See parameters for alignment `align_params` in align()""" + if discovery_threshold < 0.0 or discovery_threshold > 1.0: raise ValueError( f"discovery_threshold should be between 0 and 1 inclusive. Got {discovery_threshold}") - if max_num_nodes_per_seq_char < 0: - warnings.warn("max_num_nodes_per_seq_char < 0, treating as infinite", RuntimeWarning) + if align: + json_obj = self.align(sequence, discovery_threshold, **align_params) + + def to_fasta(df): + fasta = [] + for i in range(df.shape[0]): + fasta.append(f">{df.loc[i, 'seq_description']}\n{df.loc[i, 'sequence']}") + return '\n'.join(fasta) + + sequence = to_fasta(helpers.df_from_align_result(json_obj)) param_dict = {"count_labels": True, "discovery_fraction": discovery_threshold, - "num_labels": top_labels, - "align": align, - "max_num_nodes_per_seq_char": max_num_nodes_per_seq_char} + "num_labels": top_labels} return self._json_seq_query(sequence, param_dict, "search") @@ -65,19 +77,16 @@ def align(self, sequence: Union[str, Iterable[str]], params = {'max_alternative_alignments': max_alternative_alignments, 'max_num_nodes_per_seq_char': max_num_nodes_per_seq_char, 'discovery_fraction': discovery_threshold} - return self._json_seq_query(sequence, params, "align") - # noinspection PyTypeChecker - def column_labels(self) -> Tuple[JsonStrList, str]: - return self._do_request("column_labels", {}, False) + return self._json_seq_query(sequence, params, "align") def _json_seq_query(self, sequence: Union[str, Iterable[str]], param_dict, endpoint: str) -> Tuple[JsonDict, str]: if isinstance(sequence, str): - fasta_str = f">query\n{sequence}" - else: - fasta_str = '\n'.join( - [f">{i}\n{seq}" for (i, seq) in enumerate(sequence)]) + sequence = [sequence] + + fasta_str = '\n'.join( + [f">{i}\n{seq}" for (i, seq) in enumerate(sequence)]) payload_dict = {"FASTA": fasta_str} payload_dict.update(param_dict) @@ -86,12 +95,7 @@ def _json_seq_query(self, sequence: Union[str, Iterable[str]], param_dict, return self._do_request(endpoint, payload) def _do_request(self, endpoint, payload, post_req=True) -> Tuple[JsonDict, str]: - endpoint_path = endpoint - - if self.api_path: - endpoint_path = f"{self.api_path.lstrip('/')}/{endpoint}" - - url = f'http://{self.host}:{self.port}/{endpoint_path}' + url = f'{self.server}/{endpoint}' if post_req: ret = requests.post(url=url, json=payload) else: @@ -100,36 +104,48 @@ def _do_request(self, endpoint, payload, post_req=True) -> Tuple[JsonDict, str]: try: json_obj = ret.json() except: - return {}, str(ret.status_code) + " " + str(ret) + raise RuntimeError( + f"Error while calling the server API. {str(ret.status_code)}: {ret.text}") if not ret.ok: - error_msg = json_obj[ - 'error'] if 'error' in json_obj.keys() else str(json_obj) - return {}, str(ret.status_code) + " " + error_msg + error_msg = json_obj['error'] if 'error' in json_obj.keys() else str(json_obj) + raise RuntimeError( + f"Error while calling the server API. {str(ret.status_code)}: {error_msg}") - return json_obj, "" + return json_obj + + # noinspection PyTypeChecker + def column_labels(self) -> Tuple[JsonStrList, str]: + return self._do_request("column_labels", {}, post_req=False) def stats(self) -> Tuple[dict, str]: return self._do_request("stats", {}, post_req=False) + def ready(self) -> bool: + try: + self.stats() + return True + except RuntimeError as e: + if "503: Server is currently initializing" in str(e): + return False + raise e + class GraphClient: - def __init__(self, host: str, port: int, label: str = None, api_path: str = None): - self._json_client = GraphClientJson(host, port, api_path=api_path) - self.label = label + def __init__(self, host: str, port: int, name: str = None, api_path: str = None): + self._json_client = GraphClientJson(host, port, name, api_path=api_path) + self.name = self._json_client.name def search(self, sequence: Union[str, Iterable[str]], top_labels: int = DEFAULT_TOP_LABELS, discovery_threshold: float = DEFAULT_DISCOVERY_THRESHOLD, align: bool = False, - max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> pd.DataFrame: - (json_obj, err) = self._json_client.search(sequence, top_labels, - discovery_threshold, align, - max_num_nodes_per_seq_char) + **align_params) -> pd.DataFrame: + """See parameters for alignment `align_params` in align()""" - if err: - raise RuntimeError( - f"Error while calling the server API {str(err)}") + json_obj = self._json_client.search(sequence, top_labels, + discovery_threshold, + align, **align_params) return helpers.df_from_search_result(json_obj) @@ -137,21 +153,17 @@ def align(self, sequence: Union[str, Iterable[str]], discovery_threshold: float = DEFAULT_DISCOVERY_THRESHOLD, max_alternative_alignments: int = 1, max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> pd.DataFrame: - json_obj, err = self._json_client.align(sequence, discovery_threshold, - max_alternative_alignments, max_num_nodes_per_seq_char) - - if err: - raise RuntimeError(f"Error while calling the server API {str(err)}") + json_obj = self._json_client.align(sequence, discovery_threshold, + max_alternative_alignments, + max_num_nodes_per_seq_char) return helpers.df_from_align_result(json_obj) - def column_labels(self) -> List[str]: - json_obj, err = self._json_client.column_labels() + return self._json_client.column_labels() - if err: - raise RuntimeError(f"Error while calling the server API {str(err)}") - return json_obj + def ready(self) -> bool: + return self._json_client.ready() class MultiGraphClient: @@ -159,11 +171,9 @@ class MultiGraphClient: def __init__(self): self.graphs = {} - def add_graph(self, host: str, port: int, label: str = None, api_path: str = None) -> None: - if not label: - label = f"{host}:{port}" - - self.graphs[label] = GraphClient(host, port, label, api_path=api_path) + def add_graph(self, host: str, port: int, name: str = None, api_path: str = None) -> None: + graph_client = GraphClient(host, port, name, api_path=api_path) + self.graphs[graph_client.name] = graph_client def list_graphs(self) -> Dict[str, Tuple[str, int]]: return {lbl: (inst.host, inst.port) for (lbl, inst) in @@ -173,15 +183,14 @@ def search(self, sequence: Union[str, Iterable[str]], top_labels: int = DEFAULT_TOP_LABELS, discovery_threshold: float = DEFAULT_DISCOVERY_THRESHOLD, align: bool = False, - max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> \ - Dict[str, pd.DataFrame]: + **align_params) -> Dict[str, pd.DataFrame]: + """See parameters for alignment `align_params` in align()""" result = {} - for label, graph_client in self.graphs.items(): - result[label] = graph_client.search(sequence, top_labels, + for name, graph_client in self.graphs.items(): + result[name] = graph_client.search(sequence, top_labels, discovery_threshold, - align, - max_num_nodes_per_seq_char) + align, **align_params) return result @@ -191,9 +200,9 @@ def align(self, sequence: Union[str, Iterable[str]], max_num_nodes_per_seq_char: float = DEFAULT_NUM_NODES_PER_SEQ_CHAR) -> Dict[ str, pd.DataFrame]: result = {} - for label, graph_client in self.graphs.items(): + for name, graph_client in self.graphs.items(): # TODO: do this async - result[label] = graph_client.align(sequence, discovery_threshold, + result[name] = graph_client.align(sequence, discovery_threshold, max_alternative_alignments, max_num_nodes_per_seq_char) @@ -201,7 +210,7 @@ def align(self, sequence: Union[str, Iterable[str]], def column_labels(self) -> Dict[str, List[str]]: ret = {} - for label, graph_client in self.graphs.items(): - ret[label] = graph_client.column_labels() + for name, graph_client in self.graphs.items(): + ret[name] = graph_client.column_labels() return ret diff --git a/metagraph/api/python/metagraph/helpers.py b/metagraph/api/python/metagraph/helpers.py index 26d858b525..a7e49c6dab 100644 --- a/metagraph/api/python/metagraph/helpers.py +++ b/metagraph/api/python/metagraph/helpers.py @@ -4,6 +4,7 @@ def df_from_search_result(json_res): def _build_dict(row, result): d = dict(row) + # `properties` (optional): dictionary with metadata about the sample if 'properties' in d.keys(): props = d.pop('properties') else: @@ -11,36 +12,24 @@ def _build_dict(row, result): props['seq_description'] = result['seq_description'] - if props['seq_description'] == 'query': - props['seq_description'] = 0 # for consistency - - if 'cigar' in result.keys(): - # we did alignment - props['sequence'] = result['sequence'] - props['score'] = result['score'] - props['cigar'] = result['cigar'] - return {**d, **props} lst = [_build_dict(row, result) for result in json_res for row in result['results']] - if lst: - return pd.DataFrame(lst) - else: - # columns may vary on the graph, so not adding column information - return pd.DataFrame() + # columns of the table may vary on the graph, and inferred automatically + return pd.DataFrame(lst) def df_from_align_result(json_res): # flatten out json result - lst = [alignment for result in json_res for alignment in - result['alignments']] + lst = [(alignment['cigar'], + alignment['score'], + alignment['sequence'], + result['seq_description']) + for result in json_res for alignment in result['alignments']] df = pd.DataFrame(lst, columns=['cigar', 'score', 'sequence', 'seq_description']) - # for consistency, set seq_description to 0 even if we only queried a single sequence - df.loc[df['seq_description'] == 'query', 'seq_description'] = 0 - return df diff --git a/metagraph/integration_tests/test_api.py b/metagraph/integration_tests/test_api.py index 54e8ac1412..29a677e75c 100644 --- a/metagraph/integration_tests/test_api.py +++ b/metagraph/integration_tests/test_api.py @@ -223,8 +223,7 @@ def test_api_simple_query_align_df(self): ret = self.graph_client.search(self.sample_query, discovery_threshold=0.01, align=True) df = ret[self.graph_name] - self.assertIn('cigar', df.columns) - self.assertEqual((self.sample_query_expected_rows, 6), df.shape) + self.assertEqual((self.sample_query_expected_rows, 3), df.shape) def test_api_client_column_labels(self): ret = self.graph_client.column_labels() @@ -274,13 +273,17 @@ def setUpClass(cls): cls.graph_client = GraphClientJson(cls.host, cls.port) + def setUp(self): + if not self.graph_client.ready(): + self.fail("Server takes too long to initialize") + def test_api_align_json(self): - ret, _ = self.graph_client.align("TCGATCGA") + ret = self.graph_client.align("TCGATCGA") self.assertEqual(len(ret), 1) # do various queries def test_api_simple_query(self): - res_list, _ = self.graph_client.search(self.sample_query, discovery_threshold=0.01) + res_list = self.graph_client.search(self.sample_query, discovery_threshold=0.01) self.assertEqual(len(res_list), 1) @@ -297,14 +300,14 @@ def test_api_simple_query(self): def test_api_multiple_queries(self): repetitions = 4 - res_list, _ = self.graph_client.search([self.sample_query] * repetitions) + res_list = self.graph_client.search([self.sample_query] * repetitions) self.assertEqual(len(res_list), repetitions) # testing if the returned query indices range from 0 to n - 1 self.assertEqual(sorted(range(0, repetitions)), sorted([int(a['seq_description']) for a in res_list])) def test_api_stats(self): - res = self.graph_client.stats()[0] + res = self.graph_client.stats() self.assertIn("graph", res.keys()) graph_props = res['graph'] diff --git a/metagraph/src/cli/server.cpp b/metagraph/src/cli/server.cpp index 67892bc845..0f19bf96b6 100644 --- a/metagraph/src/cli/server.cpp +++ b/metagraph/src/cli/server.cpp @@ -290,6 +290,7 @@ std::thread start_server(HttpServer &server_startup, Config &config) { template bool check_data_ready(std::shared_future &data, shared_ptr response) { if (data.wait_for(0s) != std::future_status::ready) { + logger->info("[Server] Got a request during initialization. Asked to come back later"); response->write(SimpleWeb::StatusCode::server_error_service_unavailable, "Server is currently initializing, please come back later."); return false;