diff --git a/llm/cli.py b/llm/cli.py index 966c2e47..9b0109fa 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -716,7 +716,8 @@ def logs_list( if conversation_id: where_bits.append("responses.conversation_id = :conversation_id") if where_bits: - sql_format["extra_where"] = " where " + " and ".join(where_bits) + where_ = " and " if query else " where " + sql_format["extra_where"] = where_ + " and ".join(where_bits) final_sql = sql.format(**sql_format) rows = list( diff --git a/tests/test_llm.py b/tests/test_llm.py index 795c4146..e86a44e9 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -146,16 +146,19 @@ def test_logs_filtered(user_path, model): @pytest.mark.parametrize( - "query,expected", + "query,extra_args,expected", ( # With no search term order should be by datetime - ("", ["doc1", "doc2", "doc3"]), + ("", [], ["doc1", "doc2", "doc3"]), # With a search it's order by rank instead - ("llama", ["doc1", "doc3"]), - ("alpaca", ["doc2"]), + ("llama", [], ["doc1", "doc3"]), + ("alpaca", [], ["doc2"]), + # Model filter should work too + ("llama", ["-m", "davinci"], ["doc1", "doc3"]), + ("llama", ["-m", "davinci2"], []), ), ) -def test_logs_search(user_path, query, expected): +def test_logs_search(user_path, query, extra_args, expected): log_path = str(user_path / "logs.db") db = sqlite_utils.Database(log_path) migrate(db) @@ -175,7 +178,7 @@ def _insert(id, text): _insert("doc2", "alpaca") _insert("doc3", "llama llama") runner = CliRunner() - result = runner.invoke(cli, ["logs", "list", "-q", query, "--json"]) + result = runner.invoke(cli, ["logs", "list", "-q", query, "--json"] + extra_args) assert result.exit_code == 0 records = json.loads(result.output.strip()) assert [record["id"] for record in records] == expected