diff --git a/collection.go b/collection.go index c11109a..8868b43 100644 --- a/collection.go +++ b/collection.go @@ -5,6 +5,7 @@ package cmw import ( "encoding/json" + "errors" "fmt" "github.com/fxamacker/cbor/v2" @@ -14,6 +15,39 @@ type Collection struct { m map[any]CMW } +type CollectionSerialization uint + +const ( + UnknownCollectionSerialization = CollectionSerialization(iota) + CollectionSerializationJSON + CollectionSerializationCBOR +) + +func (o *Collection) Deserialize(b []byte) error { + switch b[0] { + case 0x7b: // '{' + return o.UnmarshalJSON(b) + default: + return o.UnmarshalCBOR(b) + } +} + +func (o *Collection) Serialize() ([]byte, error) { + s, err := o.detectSerialization() + if err != nil { + return nil, err + } + + switch s { + case CollectionSerializationCBOR: + return o.MarshalCBOR() + case CollectionSerializationJSON: + return o.MarshalJSON() + default: + return nil, errors.New("unsupported serialization") + } +} + // GetMap returns a pointer to the internal map func (o *Collection) GetMap() map[any]CMW { return o.m @@ -129,3 +163,32 @@ func (o *Collection) UnmarshalJSON(b []byte) error { return nil } + +func (o Collection) detectSerialization() (CollectionSerialization, error) { + rec := make(map[CollectionSerialization]bool) + + s := UnknownCollectionSerialization + + for k, v := range o.m { + switch v.serialization { + case CBORArray, CBORTag: + s = CollectionSerializationCBOR + rec[s] = true + case JSONArray: + s = CollectionSerializationJSON + rec[s] = true + default: + return UnknownCollectionSerialization, + fmt.Errorf( + "serialization not defined for collection item with k %v", k, + ) + } + } + + if len(rec) != 1 { + return UnknownCollectionSerialization, + errors.New("CMW collection has items with incompatible serializations") + } + + return s, nil +} diff --git a/collection_test.go b/collection_test.go index e17d394..d55e089 100644 --- a/collection_test.go +++ b/collection_test.go @@ -62,7 +62,7 @@ func Test_Collection_JSON_Serialize_ok(t *testing.T) { tv.SetItem("b", b) - actual, err := tv.MarshalJSON() + actual, err := tv.Serialize() assert.NoError(t, err) assert.JSONEq(t, string(expected), string(actual)) @@ -124,7 +124,7 @@ func Test_Collection_CBOR_Serialize_ok(t *testing.T) { expected := mustReadFile(t, "testdata/collection-cbor-ok-2.cbor") - b, err := tv.MarshalCBOR() + b, err := tv.Serialize() assert.NoError(t, err) assert.Equal(t, expected, b) } @@ -147,3 +147,41 @@ func Test_Collection_CBOR_Deserialize_and_iterate(t *testing.T) { } } } + +func Test_Collection_detectSerialization_fail(t *testing.T) { + var tv Collection + + var a CMW + a.SetMediaType("application/vnd.a") + a.SetValue([]byte{0x61}) + a.SetSerialization(JSONArray) + + tv.SetItem("a", a) + + var b CMW + b.SetMediaType("application/vnd.b") + b.SetValue([]byte{0x62}) + b.SetSerialization(CBORArray) + + tv.SetItem("b", b) + + s, err := tv.detectSerialization() + assert.EqualError(t, err, "CMW collection has items with incompatible serializations") + assert.Equal(t, UnknownCollectionSerialization, s) +} + +func Test_Collection_Deserialize_JSON_ok(t *testing.T) { + tv := mustReadFile(t, "testdata/collection-ok.json") + + var c Collection + err := c.Deserialize(tv) + assert.NoError(t, err) +} + +func Test_Collection_Deserialize_CBOR_ok(t *testing.T) { + tv := mustReadFile(t, "testdata/collection-cbor-ok.cbor") + + var c Collection + err := c.Deserialize(tv) + assert.NoError(t, err) +}