From 662352e8036d6bd596643bb1cbc036fe7f9b061a Mon Sep 17 00:00:00 2001 From: Daniel Boll <43689101+Daniel-Boll@users.noreply.github.com> Date: Sat, 10 Aug 2024 21:56:49 -0300 Subject: [PATCH] feat(scylla): upgrade Scylla driver and add UDT support (#37) This commit upgrades the Scylla driver to version 0.13.1 and adds support for User Defined Types (UDTs). The `QueryParameter` struct has been updated to handle UDTs and the `QueryResult` struct now parses UDTs correctly. The `execute`, `query`, and `batch` methods in `ScyllaSession` have been updated to handle parameters of UDTs. The `Uuid` struct has been updated to be cloneable and copyable. Signed-off-by: Daniel Boll --- Cargo.toml | 6 ++- examples/udt.mts | 44 +++++++++++++++ index.d.ts | 22 ++++---- src/helpers/query_parameter.rs | 99 ++++++++++++++++++++++++++++------ src/helpers/query_results.rs | 75 +++++++++++++++----------- src/session/scylla_session.rs | 18 +++++-- src/types/uuid.rs | 8 ++- 7 files changed, 208 insertions(+), 64 deletions(-) create mode 100644 examples/udt.mts diff --git a/Cargo.toml b/Cargo.toml index aca481f..341ab60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,11 @@ napi = { version = "2.13.3", default-features = false, features = [ ] } napi-derive = "2.13.0" tokio = { version = "1", features = ["full"] } -scylla = { version = "0.10.1", features = ["ssl"] } +scylla = { version = "0.13.1", features = [ + "ssl", + "full-serialization", + "cloud", +] } uuid = { version = "1.4.1", features = ["serde", "v4", "fast-rng"] } serde_json = "1.0" openssl = { version = "0.10", features = ["vendored"] } diff --git a/examples/udt.mts b/examples/udt.mts new file mode 100644 index 0000000..3cefc1d --- /dev/null +++ b/examples/udt.mts @@ -0,0 +1,44 @@ +import { Cluster } from "../index.js"; + +const nodes = process.env.CLUSTER_NODES?.split(",") ?? ["127.0.0.1:9042"]; + +console.log(`Connecting to ${nodes}`); + +const cluster = new Cluster({ nodes }); +const session = await cluster.connect(); + +await session.execute( + "CREATE KEYSPACE IF NOT EXISTS udt WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }", +); +await session.useKeyspace("udt"); + +await session.execute( + "CREATE TYPE IF NOT EXISTS address (street text, neighbor text)", +); +await session.execute( + "CREATE TABLE IF NOT EXISTS user (name text, address address, primary key (name))", +); + +interface User { + name: string; + address: { + street: string; + neighbor: string; + }; +} + +const user: User = { + name: "John Doe", + address: { + street: "123 Main St", + neighbor: "Downtown", + }, +}; + +await session.execute("INSERT INTO user (name, address) VALUES (?, ?)", [ + user.name, + user.address, +]); + +const users = (await session.execute("SELECT * FROM user")) as User[]; +console.log(users); diff --git a/index.d.ts b/index.d.ts index 654479a..7ec3ef8 100644 --- a/index.d.ts +++ b/index.d.ts @@ -92,7 +92,7 @@ export interface ScyllaMaterializedView { baseTableName: string } export type ScyllaCluster = Cluster -export declare class Cluster { +export class Cluster { /** * Object config is in the format: * { @@ -111,7 +111,7 @@ export type ScyllaBatchStatement = BatchStatement * These statements can be simple or prepared. * Only INSERT, UPDATE and DELETE statements are allowed. */ -export declare class BatchStatement { +export class BatchStatement { constructor() /** * Appends a statement to the batch. @@ -121,17 +121,17 @@ export declare class BatchStatement { */ appendStatement(statement: Query | PreparedStatement): void } -export declare class PreparedStatement { +export class PreparedStatement { setConsistency(consistency: Consistency): void setSerialConsistency(serialConsistency: SerialConsistency): void } -export declare class Query { +export class Query { constructor(query: string) setConsistency(consistency: Consistency): void setSerialConsistency(serialConsistency: SerialConsistency): void setPageSize(pageSize: number): void } -export declare class Metrics { +export class Metrics { /** Returns counter for nonpaged queries */ getQueriesNum(): bigint /** Returns counter for pages requested in paged queries */ @@ -151,11 +151,11 @@ export declare class Metrics { */ getLatencyPercentileMs(percentile: number): bigint } -export declare class ScyllaSession { +export class ScyllaSession { metrics(): Metrics getClusterData(): Promise - execute(query: string | Query | PreparedStatement, parameters?: Array | undefined | null): Promise - query(scyllaQuery: Query, parameters?: Array | undefined | null): Promise + execute(query: string | Query | PreparedStatement, parameters?: Array> | undefined | null): Promise + query(scyllaQuery: Query, parameters?: Array> | undefined | null): Promise prepare(query: string): Promise /** * Perform a batch query\ @@ -194,7 +194,7 @@ export declare class ScyllaSession { * console.log(await session.execute("SELECT * FROM users")); * ``` */ - batch(batch: BatchStatement, parameters: Array | undefined | null>): Promise + batch(batch: BatchStatement, parameters: Array> | undefined | null>): Promise /** * Sends `USE ` request on all connections\ * This allows to write `SELECT * FROM table` instead of `SELECT * FROM keyspace.table`\ @@ -264,14 +264,14 @@ export declare class ScyllaSession { awaitSchemaAgreement(): Promise checkSchemaAgreement(): Promise } -export declare class ScyllaClusterData { +export class ScyllaClusterData { /** * Access keyspaces details collected by the driver Driver collects various schema details like * tables, partitioners, columns, types. They can be read using this method */ getKeyspaceInfo(): Record | null } -export declare class Uuid { +export class Uuid { /** Generates a random UUID v4. */ static randomV4(): Uuid /** Parses a UUID from a string. It may fail if the string is not a valid UUID. */ diff --git a/src/helpers/query_parameter.rs b/src/helpers/query_parameter.rs index df6e36e..f347e1c 100644 --- a/src/helpers/query_parameter.rs +++ b/src/helpers/query_parameter.rs @@ -1,25 +1,90 @@ +use std::collections::HashMap; + use crate::types::uuid::Uuid; -use napi::bindgen_prelude::Either3; -use scylla::_macro_internal::SerializedValues; +use napi::bindgen_prelude::{Either3, Either4}; +use scylla::{ + frame::response::result::CqlValue, + serialize::{ + row::{RowSerializationContext, SerializeRow}, + value::SerializeCql, + RowWriter, SerializationError, + }, +}; -pub struct QueryParameter { - pub(crate) parameters: Option>>, +pub struct QueryParameter<'a> { + #[allow(clippy::type_complexity)] + pub(crate) parameters: + Option>>>>, } -impl QueryParameter { - pub fn parser(parameters: Option>>) -> Option { - parameters - .map(|params| { - let mut values = SerializedValues::with_capacity(params.len()); - for param in params { - match param { - Either3::A(number) => values.add_value(&(number as i32)).unwrap(), - Either3::B(str) => values.add_value(&str).unwrap(), - Either3::C(uuid) => values.add_value(&(uuid.uuid)).unwrap(), +impl<'a> SerializeRow for QueryParameter<'a> { + fn serialize( + &self, + ctx: &RowSerializationContext<'_>, + writer: &mut RowWriter, + ) -> Result<(), SerializationError> { + if let Some(parameters) = &self.parameters { + for (i, parameter) in parameters.iter().enumerate() { + match parameter { + Either4::A(num) => { + CqlValue::Int(*num as i32) + .serialize(&ctx.columns()[i].typ, writer.make_cell_writer())?; + } + Either4::B(str) => { + CqlValue::Text(str.to_string()) + .serialize(&ctx.columns()[i].typ, writer.make_cell_writer())?; + } + Either4::C(uuid) => { + CqlValue::Uuid(uuid.get_inner()) + .serialize(&ctx.columns()[i].typ, writer.make_cell_writer())?; + } + Either4::D(map) => { + CqlValue::UserDefinedType { + // FIXME: I'm not sure why this is even necessary tho, but if it's and makes sense we'll have to make it so we get the correct info + keyspace: "keyspace".to_string(), + type_name: "type_name".to_string(), + fields: map + .iter() + .map(|(key, value)| match value { + Either3::A(num) => (key.to_string(), Some(CqlValue::Int(*num as i32))), + Either3::B(str) => (key.to_string(), Some(CqlValue::Text(str.to_string()))), + Either3::C(uuid) => (key.to_string(), Some(CqlValue::Uuid(uuid.get_inner()))), + }) + .collect::)>>(), + } + .serialize(&ctx.columns()[i].typ, writer.make_cell_writer())?; } } - values - }) - .or(Some(SerializedValues::new())) + } + } + Ok(()) + } + + fn is_empty(&self) -> bool { + self.parameters.is_none() || self.parameters.as_ref().unwrap().is_empty() + } +} + +impl<'a> QueryParameter<'a> { + #[allow(clippy::type_complexity)] + pub fn parser( + parameters: Option< + Vec>>>, + >, + ) -> Option { + if parameters.is_none() { + return Some(QueryParameter { parameters: None }); + } + + let parameters = parameters.unwrap(); + + let mut params = Vec::with_capacity(parameters.len()); + for parameter in parameters { + params.push(parameter); + } + + Some(QueryParameter { + parameters: Some(params), + }) } } diff --git a/src/helpers/query_results.rs b/src/helpers/query_results.rs index cb7e484..c4863e9 100644 --- a/src/helpers/query_results.rs +++ b/src/helpers/query_results.rs @@ -1,58 +1,73 @@ -use scylla::frame::response::result::ColumnType; +use scylla::frame::response::result::{ColumnType, CqlValue}; pub struct QueryResult { pub(crate) result: scylla::QueryResult, } impl QueryResult { pub fn parser(result: scylla::QueryResult) -> serde_json::Value { - if result.result_not_rows().is_ok() { - return serde_json::json!([]); - } - - if result.rows.is_none() { + if result.result_not_rows().is_ok() || result.rows.is_none() { return serde_json::json!([]); } let rows = result.rows.unwrap(); let column_specs = result.col_specs; - let mut result = serde_json::json!([]); + let mut result_json = serde_json::json!([]); for row in rows { let mut row_object = serde_json::Map::new(); for (i, column) in row.columns.iter().enumerate() { let column_name = column_specs[i].name.clone(); - - let column_value = match column { - Some(column) => match column_specs[i].typ { - ColumnType::Ascii => serde_json::Value::String(column.as_ascii().unwrap().to_string()), - ColumnType::Text => serde_json::Value::String(column.as_text().unwrap().to_string()), - ColumnType::Uuid => serde_json::Value::String(column.as_uuid().unwrap().to_string()), - ColumnType::Int => serde_json::Value::Number( - serde_json::Number::from_f64(column.as_int().unwrap() as f64).unwrap(), - ), - ColumnType::Float => serde_json::Value::Number( - serde_json::Number::from_f64(column.as_float().unwrap() as f64).unwrap(), - ), - ColumnType::Timestamp => { - serde_json::Value::String(column.as_date().unwrap().to_string()) - } - ColumnType::Date => serde_json::Value::String(column.as_date().unwrap().to_string()), - _ => "Not implemented".into(), - }, - None => serde_json::Value::Null, - }; - + let column_value = Self::parse_value(column, &column_specs[i].typ); row_object.insert(column_name, column_value); } - result + result_json .as_array_mut() .unwrap() .push(serde_json::Value::Object(row_object)); } - result + result_json + } + + fn parse_value(column: &Option, column_type: &ColumnType) -> serde_json::Value { + match column { + Some(column) => match column_type { + ColumnType::Ascii => serde_json::Value::String(column.as_ascii().unwrap().to_string()), + ColumnType::Text => serde_json::Value::String(column.as_text().unwrap().to_string()), + ColumnType::Uuid => serde_json::Value::String(column.as_uuid().unwrap().to_string()), + ColumnType::Int => serde_json::Value::Number( + serde_json::Number::from_f64(column.as_int().unwrap() as f64).unwrap(), + ), + ColumnType::Float => serde_json::Value::Number( + serde_json::Number::from_f64(column.as_float().unwrap() as f64).unwrap(), + ), + ColumnType::Timestamp | ColumnType::Date => { + serde_json::Value::String(column.as_cql_date().unwrap().0.to_string()) + } + ColumnType::UserDefinedType { field_types, .. } => { + Self::parse_udt(column.as_udt().unwrap(), field_types) + } + _ => "ColumnType currently not implemented".into(), + }, + None => serde_json::Value::Null, + } + } + + fn parse_udt( + udt: &[(String, Option)], + field_types: &[(String, ColumnType)], + ) -> serde_json::Value { + let mut result = serde_json::Map::new(); + + for (i, (field_name, field_value)) in udt.iter().enumerate() { + let field_type = &field_types[i].1; + let parsed_value = Self::parse_value(field_value, field_type); + result.insert(field_name.clone(), parsed_value); + } + + serde_json::Value::Object(result) } } diff --git a/src/session/scylla_session.rs b/src/session/scylla_session.rs index 1969da2..5ad8fb0 100644 --- a/src/session/scylla_session.rs +++ b/src/session/scylla_session.rs @@ -1,10 +1,12 @@ +use std::collections::HashMap; + use crate::helpers::query_parameter::QueryParameter; use crate::helpers::query_results::QueryResult; use crate::query::batch_statement::ScyllaBatchStatement; use crate::query::scylla_prepared_statement::PreparedStatement; use crate::query::scylla_query::Query; use crate::types::uuid::Uuid; -use napi::bindgen_prelude::Either3; +use napi::bindgen_prelude::{Either3, Either4}; use super::metrics; use super::topology::ScyllaClusterData; @@ -37,11 +39,14 @@ impl ScyllaSession { cluster_data.into() } + #[allow(clippy::type_complexity)] #[napi] pub async fn execute( &self, query: Either3, - parameters: Option>>, + parameters: Option< + Vec>>>, + >, ) -> napi::Result { let values = QueryParameter::parser(parameters.clone()).ok_or(napi::Error::new( napi::Status::InvalidArg, @@ -69,11 +74,14 @@ impl ScyllaSession { Ok(QueryResult::parser(query_result)) } + #[allow(clippy::type_complexity)] #[napi] pub async fn query( &self, scylla_query: &Query, - parameters: Option>>, + parameters: Option< + Vec>>>, + >, ) -> napi::Result { let values = QueryParameter::parser(parameters.clone()).ok_or(napi::Error::new( napi::Status::InvalidArg, @@ -146,7 +154,9 @@ impl ScyllaSession { pub async fn batch( &self, batch: &ScyllaBatchStatement, - parameters: Vec>>>, + parameters: Vec< + Option>>>>, + >, ) -> napi::Result { let values = parameters .iter() diff --git a/src/types/uuid.rs b/src/types/uuid.rs index 3a0442c..abe453b 100644 --- a/src/types/uuid.rs +++ b/src/types/uuid.rs @@ -1,7 +1,7 @@ use napi::Result; #[napi()] -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub struct Uuid { pub(crate) uuid: uuid::Uuid, } @@ -18,6 +18,12 @@ impl From for uuid::Uuid { } } +impl Uuid { + pub(crate) fn get_inner(&self) -> uuid::Uuid { + self.uuid + } +} + #[napi] impl Uuid { /// Generates a random UUID v4.