Skip to content

Commit

Permalink
feat: plugin comms interface can handle multiple active sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
j-lanson committed Aug 23, 2024
1 parent 79414b1 commit cbd4d48
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

80 changes: 57 additions & 23 deletions hipcheck/src/plugin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::Result;
use futures::future::join_all;
use serde_json::Value;
use std::collections::HashMap;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, Mutex};

pub fn dummy() {
let plugin = Plugin {
Expand Down Expand Up @@ -41,9 +41,58 @@ pub async fn initialize_plugins(
Ok(out)
}

struct ActivePlugin {
next_id: Mutex<usize>,
channel: PluginTransport,
}
impl ActivePlugin {
pub fn new(channel: PluginTransport) -> Self {
ActivePlugin {
next_id: Mutex::new(0),
channel,
}
}
async fn get_unique_id(&self) -> usize {
let mut id_lock = self.next_id.lock().await;
let res: usize = *id_lock;
*id_lock += 2;
drop(id_lock);
res
}
pub async fn query(&self, name: String, key: Value) -> Result<PluginResponse> {
let id = self.get_unique_id().await;
let query = Query {
id,
request: true,
publisher: "".to_owned(),
plugin: self.channel.name().to_owned(),
query: name,
key,
output: serde_json::json!(null),
};
Ok(self.channel.query(query).await?.into())
}
pub async fn resume_query(
&self,
state: AwaitingResult,
output: Value,
) -> Result<PluginResponse> {
let query = Query {
id: state.id,
request: false,
publisher: state.publisher,
plugin: state.plugin,
query: state.query,
key: serde_json::json!(null),
output,
};
Ok(self.channel.query(query).await?.into())
}
}

pub struct HcPluginCore {
executor: PluginExecutor,
plugins: HashMap<String, PluginTransport>,
plugins: HashMap<String, ActivePlugin>,
}
impl HcPluginCore {
// When this object is returned, the plugins are all connected but the
Expand All @@ -69,36 +118,21 @@ impl HcPluginCore {
})
.collect();
// Use configs to initialize corresponding plugin
let plugins = HashMap::<String, PluginTransport>::from_iter(
let plugins = HashMap::<String, ActivePlugin>::from_iter(
initialize_plugins(mapped_ctxs)
.await?
.into_iter()
.map(|p| (p.name().to_owned(), p)),
.map(|p| (p.name().to_owned(), ActivePlugin::new(p))),
);
// Now we have a set of started and initialized plugins to interact with
Ok(HcPluginCore { executor, plugins })
}
// @Temporary
pub async fn run(&mut self) -> Result<()> {
let channel = self.plugins.get_mut("rand_data").unwrap();
match channel
.send(Query {
id: 1,
request: true,
publisher: "".to_owned(),
plugin: "".to_owned(),
query: "rand_data".to_owned(),
key: serde_json::json!(7),
output: serde_json::json!(null),
})
.await
{
Ok(q) => q,
Err(e) => {
println!("Failed: {e}");
}
};
let resp = channel.recv().await?;
let handle = self.plugins.get("rand_data").unwrap();
let resp = handle
.query("rand_data".to_owned(), serde_json::json!(7))
.await?;
println!("Plugin response: {resp:?}");
Ok(())
}
Expand Down
154 changes: 132 additions & 22 deletions hipcheck/src/plugin/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use crate::hipcheck::{
};
use crate::{hc_error, Error, Result, StdResult};
use serde_json::Value;
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
use std::convert::TryFrom;
use std::ops::Not;
use std::process::Child;
use tokio::sync::{mpsc, Mutex};
use tonic::codec::Streaming;
use tonic::transport::Channel;

Expand Down Expand Up @@ -161,7 +162,7 @@ impl PluginContext {
}
pub async fn initiate_query_protocol(
&mut self,
mut rx: tokio::sync::mpsc::Receiver<PluginQuery>,
mut rx: mpsc::Receiver<PluginQuery>,
) -> Result<Streaming<PluginQuery>> {
let stream = async_stream::stream! {
while let Some(item) = rx.recv().await {
Expand All @@ -185,8 +186,9 @@ impl PluginContext {
);
self.set_configuration(&config).await?.as_result()?;
let default_policy_expr = self.get_default_policy_expression().await?;
let (tx, mut out_rx) = tokio::sync::mpsc::channel::<PluginQuery>(10);
let (tx, mut out_rx) = mpsc::channel::<PluginQuery>(10);
let rx = self.initiate_query_protocol(out_rx).await?;
let rx = Mutex::new(MultiplexedQueryReceiver::new(rx));
Ok(PluginTransport {
schemas,
default_policy_expr,
Expand Down Expand Up @@ -271,46 +273,103 @@ impl TryFrom<Query> for PluginQuery {
}
}

pub struct MultiplexedQueryReceiver {
rx: Streaming<PluginQuery>,
backlog: HashMap<i32, VecDeque<PluginQuery>>,
}
impl MultiplexedQueryReceiver {
pub fn new(rx: Streaming<PluginQuery>) -> Self {
Self {
rx,
backlog: HashMap::new(),
}
}
// @Invariant - this function will never return an empty VecDeque
pub async fn recv(&mut self, id: i32) -> Result<Option<VecDeque<PluginQuery>>> {
// If we have 1+ messages on backlog for `id`, return them all,
// no need to waste time with successive calls
if let Some(msgs) = self.backlog.remove(&id) {
return Ok(Some(msgs));
}
// No backlog message, need to operate the receiver
loop {
let Some(raw) = self.rx.message().await? else {
// gRPC channel was closed
return Ok(None);
};
let raw_id = raw.id;
if raw_id == id {
return Ok(Some(VecDeque::from([raw])));
}
match self.backlog.get_mut(&raw_id) {
Some(vec) => {
vec.push_back(raw);
}
None => {
self.backlog.insert(raw_id, VecDeque::from([raw]));
}
}
}
}
}

// Encapsulate an "initialized" state of a Plugin with interfaces that abstract
// query chunking to produce whole messages for the Hipcheck engine
pub struct PluginTransport {
pub schemas: HashMap<String, Schema>,
pub default_policy_expr: String, // TODO - update with policy_expr type
ctx: PluginContext,
tx: tokio::sync::mpsc::Sender<PluginQuery>,
rx: Streaming<PluginQuery>,
tx: mpsc::Sender<PluginQuery>,
rx: Mutex<MultiplexedQueryReceiver>,
}
impl PluginTransport {
pub fn name(&self) -> &str {
&self.ctx.plugin.name
}
pub async fn send(&mut self, query: Query) -> Result<()> {
pub async fn query(&self, query: Query) -> Result<Option<Query>> {
use QueryState::*;

// Send the query
let query: PluginQuery = query.try_into()?;
eprintln!("Sending query: {query:?}");
let id = query.id;
self.tx
.send(query)
.await
.map_err(|e| hc_error!("sending query failed: {}", e))
}
pub async fn recv(&mut self) -> Result<Option<Query>> {
use QueryState::*;
let Some(mut raw) = self.rx.message().await? else {
// gRPC channel was closed
.map_err(|e| hc_error!("sending query failed: {}", e))?;

// Get initial response batch
let mut rx_handle = self.rx.lock().await;
let Some(mut msg_chunks) = rx_handle.recv(id).await? else {
return Ok(None);
};
drop(rx_handle);

let mut raw = msg_chunks.pop_front().unwrap();
let mut state: QueryState = raw.state.try_into()?;
// As long as we expect successive chunks, keep receiving

// If response is the first of a set of chunks, handle
if matches!(state, QueryReplyInProgress) {
while matches!(state, QueryReplyInProgress) {
let Some(next) = self.rx.message().await? else {
return Err(hc_error!(
"plugin gRPC channel closed while sending chunked message"
));
// We expect another message. Pull it off the existing queue,
// or get a new one if we have run out
let next = match msg_chunks.pop_front() {
Some(msg) => msg,
None => {
// We ran out of messages, get a new batch
let mut rx_handle = self.rx.lock().await;
match rx_handle.recv(id).await? {
Some(x) => {
drop(rx_handle);
msg_chunks = x;
}
None => {
return Ok(None);
}
};
msg_chunks.pop_front().unwrap()
}
};
// Assert that the ids are consistent
if next.id != raw.id {
return Err(hc_error!("msg ids from plugin do not match"));
}
// By now we have our "next" message
state = next.state.try_into()?;
match state {
QueryUnspecified => return Err(hc_error!("unspecified error from plugin")),
Expand All @@ -324,6 +383,13 @@ impl PluginTransport {
}
};
}
// Sanity check - after we've left this loop, there should be no left over message
if !msg_chunks.is_empty() {
return Err(hc_error!(
"received additional messages for id '{}' after QueryComplete status message",
id
));
}
}
raw.try_into().map(Some)
}
Expand All @@ -341,3 +407,47 @@ impl From<PluginContextWithConfig> for (PluginContext, Value) {
(value.0, value.1)
}
}

#[derive(Clone, Debug)]
pub struct AwaitingResult {
pub id: usize,
pub publisher: String,
pub plugin: String,
pub query: String,
pub key: Value,
}
impl From<Query> for AwaitingResult {
fn from(value: Query) -> Self {
AwaitingResult {
id: value.id,
publisher: value.publisher,
plugin: value.plugin,
query: value.query,
key: value.key,
}
}
}

#[derive(Clone, Debug)]
pub enum PluginResponse {
RemoteClosed,
AwaitingResult(AwaitingResult),
Completed(Value),
}
impl From<Option<Query>> for PluginResponse {
fn from(value: Option<Query>) -> Self {
match value {
Some(q) => q.into(),
None => PluginResponse::RemoteClosed,
}
}
}
impl From<Query> for PluginResponse {
fn from(value: Query) -> Self {
if !value.request {
PluginResponse::Completed(value.output)
} else {
PluginResponse::AwaitingResult(value.into())
}
}
}

0 comments on commit cbd4d48

Please sign in to comment.