Skip to content

Commit

Permalink
feat(scylla): upgrade Scylla driver and add UDT support (#37)
Browse files Browse the repository at this point in the history
This commit upgrades the Scylla driver to version 0.13.1 and adds
support for User Defined Types (UDTs). The `QueryParameter` struct has
been updated to handle UDTs and the `QueryResult` struct now parses UDTs
correctly. The `execute`, `query`, and `batch` methods in
`ScyllaSession` have been updated to handle parameters of UDTs.

The `Uuid` struct has been updated to be cloneable and copyable.

Signed-off-by: Daniel Boll <[email protected]>
  • Loading branch information
Daniel-Boll authored Aug 11, 2024
1 parent 17c7b7b commit 662352e
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 64 deletions.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ napi = { version = "2.13.3", default-features = false, features = [
] }
napi-derive = "2.13.0"
tokio = { version = "1", features = ["full"] }
scylla = { version = "0.10.1", features = ["ssl"] }
scylla = { version = "0.13.1", features = [
"ssl",
"full-serialization",
"cloud",
] }
uuid = { version = "1.4.1", features = ["serde", "v4", "fast-rng"] }
serde_json = "1.0"
openssl = { version = "0.10", features = ["vendored"] }
Expand Down
44 changes: 44 additions & 0 deletions examples/udt.mts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { Cluster } from "../index.js";

const nodes = process.env.CLUSTER_NODES?.split(",") ?? ["127.0.0.1:9042"];

console.log(`Connecting to ${nodes}`);

const cluster = new Cluster({ nodes });
const session = await cluster.connect();

await session.execute(
"CREATE KEYSPACE IF NOT EXISTS udt WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }",
);
await session.useKeyspace("udt");

await session.execute(
"CREATE TYPE IF NOT EXISTS address (street text, neighbor text)",
);
await session.execute(
"CREATE TABLE IF NOT EXISTS user (name text, address address, primary key (name))",
);

interface User {
name: string;
address: {
street: string;
neighbor: string;
};
}

const user: User = {
name: "John Doe",
address: {
street: "123 Main St",
neighbor: "Downtown",
},
};

await session.execute("INSERT INTO user (name, address) VALUES (?, ?)", [
user.name,
user.address,
]);

const users = (await session.execute("SELECT * FROM user")) as User[];
console.log(users);
22 changes: 11 additions & 11 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ export interface ScyllaMaterializedView {
baseTableName: string
}
export type ScyllaCluster = Cluster
export declare class Cluster {
export class Cluster {
/**
* Object config is in the format:
* {
Expand All @@ -111,7 +111,7 @@ export type ScyllaBatchStatement = BatchStatement
* These statements can be simple or prepared.
* Only INSERT, UPDATE and DELETE statements are allowed.
*/
export declare class BatchStatement {
export class BatchStatement {
constructor()
/**
* Appends a statement to the batch.
Expand All @@ -121,17 +121,17 @@ export declare class BatchStatement {
*/
appendStatement(statement: Query | PreparedStatement): void
}
export declare class PreparedStatement {
export class PreparedStatement {
setConsistency(consistency: Consistency): void
setSerialConsistency(serialConsistency: SerialConsistency): void
}
export declare class Query {
export class Query {
constructor(query: string)
setConsistency(consistency: Consistency): void
setSerialConsistency(serialConsistency: SerialConsistency): void
setPageSize(pageSize: number): void
}
export declare class Metrics {
export class Metrics {
/** Returns counter for nonpaged queries */
getQueriesNum(): bigint
/** Returns counter for pages requested in paged queries */
Expand All @@ -151,11 +151,11 @@ export declare class Metrics {
*/
getLatencyPercentileMs(percentile: number): bigint
}
export declare class ScyllaSession {
export class ScyllaSession {
metrics(): Metrics
getClusterData(): Promise<ScyllaClusterData>
execute(query: string | Query | PreparedStatement, parameters?: Array<number | string | Uuid> | undefined | null): Promise<any>
query(scyllaQuery: Query, parameters?: Array<number | string | Uuid> | undefined | null): Promise<any>
execute(query: string | Query | PreparedStatement, parameters?: Array<number | string | Uuid | Record<string, number | string | Uuid>> | undefined | null): Promise<any>
query(scyllaQuery: Query, parameters?: Array<number | string | Uuid | Record<string, number | string | Uuid>> | undefined | null): Promise<any>
prepare(query: string): Promise<PreparedStatement>
/**
* Perform a batch query\
Expand Down Expand Up @@ -194,7 +194,7 @@ export declare class ScyllaSession {
* console.log(await session.execute("SELECT * FROM users"));
* ```
*/
batch(batch: BatchStatement, parameters: Array<Array<number | string | Uuid> | undefined | null>): Promise<any>
batch(batch: BatchStatement, parameters: Array<Array<number | string | Uuid | Record<string, number | string | Uuid>> | undefined | null>): Promise<any>
/**
* Sends `USE <keyspace_name>` request on all connections\
* This allows to write `SELECT * FROM table` instead of `SELECT * FROM keyspace.table`\
Expand Down Expand Up @@ -264,14 +264,14 @@ export declare class ScyllaSession {
awaitSchemaAgreement(): Promise<Uuid>
checkSchemaAgreement(): Promise<boolean>
}
export declare class ScyllaClusterData {
export class ScyllaClusterData {
/**
* Access keyspaces details collected by the driver Driver collects various schema details like
* tables, partitioners, columns, types. They can be read using this method
*/
getKeyspaceInfo(): Record<string, ScyllaKeyspace> | null
}
export declare class Uuid {
export class Uuid {
/** Generates a random UUID v4. */
static randomV4(): Uuid
/** Parses a UUID from a string. It may fail if the string is not a valid UUID. */
Expand Down
99 changes: 82 additions & 17 deletions src/helpers/query_parameter.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,90 @@
use std::collections::HashMap;

use crate::types::uuid::Uuid;
use napi::bindgen_prelude::Either3;
use scylla::_macro_internal::SerializedValues;
use napi::bindgen_prelude::{Either3, Either4};
use scylla::{
frame::response::result::CqlValue,
serialize::{
row::{RowSerializationContext, SerializeRow},
value::SerializeCql,
RowWriter, SerializationError,
},
};

pub struct QueryParameter {
pub(crate) parameters: Option<Vec<Either3<u32, String, Uuid>>>,
pub struct QueryParameter<'a> {
#[allow(clippy::type_complexity)]
pub(crate) parameters:
Option<Vec<Either4<u32, String, &'a Uuid, HashMap<String, Either3<u32, String, &'a Uuid>>>>>,
}

impl QueryParameter {
pub fn parser(parameters: Option<Vec<Either3<u32, String, &Uuid>>>) -> Option<SerializedValues> {
parameters
.map(|params| {
let mut values = SerializedValues::with_capacity(params.len());
for param in params {
match param {
Either3::A(number) => values.add_value(&(number as i32)).unwrap(),
Either3::B(str) => values.add_value(&str).unwrap(),
Either3::C(uuid) => values.add_value(&(uuid.uuid)).unwrap(),
impl<'a> SerializeRow for QueryParameter<'a> {
fn serialize(
&self,
ctx: &RowSerializationContext<'_>,
writer: &mut RowWriter,
) -> Result<(), SerializationError> {
if let Some(parameters) = &self.parameters {
for (i, parameter) in parameters.iter().enumerate() {
match parameter {
Either4::A(num) => {
CqlValue::Int(*num as i32)
.serialize(&ctx.columns()[i].typ, writer.make_cell_writer())?;
}
Either4::B(str) => {
CqlValue::Text(str.to_string())
.serialize(&ctx.columns()[i].typ, writer.make_cell_writer())?;
}
Either4::C(uuid) => {
CqlValue::Uuid(uuid.get_inner())
.serialize(&ctx.columns()[i].typ, writer.make_cell_writer())?;
}
Either4::D(map) => {
CqlValue::UserDefinedType {
// FIXME: I'm not sure why this is even necessary tho, but if it's and makes sense we'll have to make it so we get the correct info
keyspace: "keyspace".to_string(),
type_name: "type_name".to_string(),
fields: map
.iter()
.map(|(key, value)| match value {
Either3::A(num) => (key.to_string(), Some(CqlValue::Int(*num as i32))),
Either3::B(str) => (key.to_string(), Some(CqlValue::Text(str.to_string()))),
Either3::C(uuid) => (key.to_string(), Some(CqlValue::Uuid(uuid.get_inner()))),
})
.collect::<Vec<(String, Option<CqlValue>)>>(),
}
.serialize(&ctx.columns()[i].typ, writer.make_cell_writer())?;
}
}
values
})
.or(Some(SerializedValues::new()))
}
}
Ok(())
}

fn is_empty(&self) -> bool {
self.parameters.is_none() || self.parameters.as_ref().unwrap().is_empty()
}
}

impl<'a> QueryParameter<'a> {
#[allow(clippy::type_complexity)]
pub fn parser(
parameters: Option<
Vec<Either4<u32, String, &'a Uuid, HashMap<String, Either3<u32, String, &'a Uuid>>>>,
>,
) -> Option<Self> {
if parameters.is_none() {
return Some(QueryParameter { parameters: None });
}

let parameters = parameters.unwrap();

let mut params = Vec::with_capacity(parameters.len());
for parameter in parameters {
params.push(parameter);
}

Some(QueryParameter {
parameters: Some(params),
})
}
}
75 changes: 45 additions & 30 deletions src/helpers/query_results.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,73 @@
use scylla::frame::response::result::ColumnType;
use scylla::frame::response::result::{ColumnType, CqlValue};
pub struct QueryResult {
pub(crate) result: scylla::QueryResult,
}

impl QueryResult {
pub fn parser(result: scylla::QueryResult) -> serde_json::Value {
if result.result_not_rows().is_ok() {
return serde_json::json!([]);
}

if result.rows.is_none() {
if result.result_not_rows().is_ok() || result.rows.is_none() {
return serde_json::json!([]);
}

let rows = result.rows.unwrap();
let column_specs = result.col_specs;

let mut result = serde_json::json!([]);
let mut result_json = serde_json::json!([]);

for row in rows {
let mut row_object = serde_json::Map::new();

for (i, column) in row.columns.iter().enumerate() {
let column_name = column_specs[i].name.clone();

let column_value = match column {
Some(column) => match column_specs[i].typ {
ColumnType::Ascii => serde_json::Value::String(column.as_ascii().unwrap().to_string()),
ColumnType::Text => serde_json::Value::String(column.as_text().unwrap().to_string()),
ColumnType::Uuid => serde_json::Value::String(column.as_uuid().unwrap().to_string()),
ColumnType::Int => serde_json::Value::Number(
serde_json::Number::from_f64(column.as_int().unwrap() as f64).unwrap(),
),
ColumnType::Float => serde_json::Value::Number(
serde_json::Number::from_f64(column.as_float().unwrap() as f64).unwrap(),
),
ColumnType::Timestamp => {
serde_json::Value::String(column.as_date().unwrap().to_string())
}
ColumnType::Date => serde_json::Value::String(column.as_date().unwrap().to_string()),
_ => "Not implemented".into(),
},
None => serde_json::Value::Null,
};

let column_value = Self::parse_value(column, &column_specs[i].typ);
row_object.insert(column_name, column_value);
}

result
result_json
.as_array_mut()
.unwrap()
.push(serde_json::Value::Object(row_object));
}

result
result_json
}

fn parse_value(column: &Option<CqlValue>, column_type: &ColumnType) -> serde_json::Value {
match column {
Some(column) => match column_type {
ColumnType::Ascii => serde_json::Value::String(column.as_ascii().unwrap().to_string()),
ColumnType::Text => serde_json::Value::String(column.as_text().unwrap().to_string()),
ColumnType::Uuid => serde_json::Value::String(column.as_uuid().unwrap().to_string()),
ColumnType::Int => serde_json::Value::Number(
serde_json::Number::from_f64(column.as_int().unwrap() as f64).unwrap(),
),
ColumnType::Float => serde_json::Value::Number(
serde_json::Number::from_f64(column.as_float().unwrap() as f64).unwrap(),
),
ColumnType::Timestamp | ColumnType::Date => {
serde_json::Value::String(column.as_cql_date().unwrap().0.to_string())
}
ColumnType::UserDefinedType { field_types, .. } => {
Self::parse_udt(column.as_udt().unwrap(), field_types)
}
_ => "ColumnType currently not implemented".into(),
},
None => serde_json::Value::Null,
}
}

fn parse_udt(
udt: &[(String, Option<CqlValue>)],
field_types: &[(String, ColumnType)],
) -> serde_json::Value {
let mut result = serde_json::Map::new();

for (i, (field_name, field_value)) in udt.iter().enumerate() {
let field_type = &field_types[i].1;
let parsed_value = Self::parse_value(field_value, field_type);
result.insert(field_name.clone(), parsed_value);
}

serde_json::Value::Object(result)
}
}
Loading

0 comments on commit 662352e

Please sign in to comment.