Skip to content

Commit

Permalink
Pretty sure this is wrong (#1589)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored and Narsil committed Aug 7, 2024
1 parent 18fad02 commit 1c19531
Showing 1 changed file with 151 additions and 1 deletion.
152 changes: 151 additions & 1 deletion bindings/python/src/utils/serde_pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ use serde::de::value::Error;
use serde::{ser, Serialize};
type Result<T> = ::std::result::Result<T, Error>;

const MAX_DEPTH: usize = 5;

pub struct Serializer {
// This string starts empty and JSON is appended as values are serialized.
output: String,
level: usize,
}

// By convention, the public API of a Serde serializer is one or more `to_abc`
Expand All @@ -18,6 +21,7 @@ where
{
let mut serializer = Serializer {
output: String::new(),
level: 0,
};
value.serialize(&mut serializer)?;
Ok(serializer.output)
Expand Down Expand Up @@ -51,6 +55,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// of the primitive types of the data model and map it to JSON by appending
// into the output string.
fn serialize_bool(self, v: bool) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += if v { "True" } else { "False" };
Ok(())
}
Expand All @@ -74,6 +83,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// Not particularly efficient but this is example code anyway. A more
// performant approach would be to use the `itoa` crate.
fn serialize_i64(self, v: i64) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += &v.to_string();
Ok(())
}
Expand All @@ -91,6 +105,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
}

fn serialize_u64(self, v: u64) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += &v.to_string();
Ok(())
}
Expand All @@ -100,6 +119,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
}

fn serialize_f64(self, v: f64) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += &v.to_string();
Ok(())
}
Expand All @@ -114,6 +138,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// get the idea. For example it would emit invalid JSON if the input string
// contains a '"' character.
fn serialize_str(self, v: &str) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += "\"";
self.output += v;
self.output += "\"";
Expand Down Expand Up @@ -152,6 +181,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// In Serde, unit means an anonymous value containing no data. Map this to
// JSON as `null`.
fn serialize_unit(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += "None";
Ok(())
}
Expand All @@ -173,6 +207,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
_variant_index: u32,
variant: &'static str,
) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
// self.serialize_str(variant)
self.output += variant;
Ok(())
Expand Down Expand Up @@ -202,6 +241,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
// variant.serialize(&mut *self)?;
self.output += variant;
self.output += "(";
Expand All @@ -221,6 +265,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// explicitly in the serialized form. Some serializers may only be able to
// support sequences for which the length is known up front.
fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
self.output += "[";
Ok(self)
}
Expand All @@ -230,6 +279,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// means that the corresponding `Deserialize implementation will know the
// length without needing to look at the serialized data.
fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
self.output += "(";
Ok(self)
}
Expand All @@ -252,6 +306,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleVariant> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
// variant.serialize(&mut *self)?;
self.output += variant;
self.output += "(";
Expand All @@ -260,6 +319,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {

// Maps are represented in JSON as `{ K: V, K: V, ... }`.
fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
println!("Serialize map");
self.output += "{";
Ok(self)
Expand All @@ -271,6 +335,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// Deserialize implementation is required to know what the keys are without
// looking at the serialized data.
fn serialize_struct(self, name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
// self.serialize_map(Some(len))
// name.serialize(&mut *self)?;
if let Some(stripped) = name.strip_suffix("Helper") {
Expand All @@ -291,6 +360,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
// variant.serialize(&mut *self)?;
self.output += variant;
self.output += "(";
Expand All @@ -316,6 +390,11 @@ impl<'a> ser::SerializeSeq for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('[') {
self.output += ", ";
}
Expand All @@ -324,6 +403,11 @@ impl<'a> ser::SerializeSeq for &'a mut Serializer {

// Close the sequence.
fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += "]";
Ok(())
}
Expand All @@ -338,13 +422,23 @@ impl<'a> ser::SerializeTuple for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
value.serialize(&mut **self)
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand All @@ -359,13 +453,23 @@ impl<'a> ser::SerializeTupleStruct for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
value.serialize(&mut **self)
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand All @@ -388,13 +492,23 @@ impl<'a> ser::SerializeTupleVariant for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
value.serialize(&mut **self)
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand Down Expand Up @@ -424,6 +538,11 @@ impl<'a> ser::SerializeMap for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('{') {
self.output += ", ";
}
Expand All @@ -437,11 +556,21 @@ impl<'a> ser::SerializeMap for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ":";
value.serialize(&mut **self)
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += "}";
Ok(())
}
Expand All @@ -457,6 +586,11 @@ impl<'a> ser::SerializeStruct for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
Expand All @@ -471,6 +605,11 @@ impl<'a> ser::SerializeStruct for &'a mut Serializer {
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand All @@ -486,6 +625,11 @@ impl<'a> ser::SerializeStructVariant for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
Expand All @@ -496,6 +640,11 @@ impl<'a> ser::SerializeStructVariant for &'a mut Serializer {
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand Down Expand Up @@ -525,7 +674,7 @@ fn test_struct() {
let expected = r#"Test(int=1, seq=["a", "b"])"#;
assert_eq!(to_string(&test).unwrap(), expected);
}

/*
#[test]
fn test_enum() {
#[derive(Serialize)]
Expand Down Expand Up @@ -657,3 +806,4 @@ fn test_flatten() {
let expected = r#"A(a=True, b=1)"#;
assert_eq!(to_string(&u).unwrap(), expected);
}
*/

0 comments on commit 1c19531

Please sign in to comment.