Skip to content

Commit

Permalink
Implement the CMW collection
Browse files Browse the repository at this point in the history
Fix #3

Signed-off-by: Thomas Fossati <[email protected]>
  • Loading branch information
thomas-fossati committed Dec 30, 2023
1 parent eb71245 commit 1e721a5
Show file tree
Hide file tree
Showing 13 changed files with 529 additions and 41 deletions.
82 changes: 55 additions & 27 deletions cmw.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@ import (
type Serialization uint

const (
JSONArray = Serialization(iota)
UnknownSerialization = Serialization(iota)
JSONArray
CBORArray
CBORTag
Unknown
)

// a CMW object holds the internal representation of a RATS conceptual message
// wrapper
type CMW struct {
typ Type
val Value
ind Indicator
typ Type
val Value
ind Indicator
serialization Serialization
}

func (o *CMW) SetMediaType(v string) { _ = o.typ.Set(v) }
Expand All @@ -41,41 +42,55 @@ func (o *CMW) SetIndicators(indicators ...Indicator) {

o.ind = v
}
func (o *CMW) SetSerialization(s Serialization) { o.serialization = s }

func (o CMW) GetValue() []byte { return o.val }
func (o CMW) GetType() string { return o.typ.String() }
func (o CMW) GetIndicator() Indicator { return o.ind }
func (o CMW) GetValue() []byte { return o.val }
func (o CMW) GetType() string { return o.typ.String() }
func (o CMW) GetIndicator() Indicator { return o.ind }
func (o CMW) GetSerialization() Serialization { return o.serialization }

// Deserialize a CMW
func (o *CMW) Deserialize(b []byte) error {
switch sniff(b) {
s := sniff(b)

o.serialization = s

switch s {
case JSONArray:
return o.UnmarshalJSON(b)
case CBORArray:
case CBORArray, CBORTag:
return o.UnmarshalCBOR(b)
case CBORTag:
return o.UnmarshalCBORTag(b)
}

return errors.New("unknown CMW format")
}

// Serialize a CMW according to the provided Serialization
func (o CMW) Serialize(s Serialization) ([]byte, error) {
// Serialize a CMW according to its provided Serialization
func (o CMW) Serialize() ([]byte, error) {
s := o.serialization
switch s {
case JSONArray:
return o.MarshalJSON()
case CBORArray:
case CBORArray, CBORTag:
return o.MarshalCBOR()
case CBORTag:
return o.MarshalCBORTag()
}
return nil, fmt.Errorf("invalid serialization format %d", s)
}

func (o CMW) MarshalJSON() ([]byte, error) { return arrayEncode(json.Marshal, &o) }
func (o CMW) MarshalCBOR() ([]byte, error) { return arrayEncode(cbor.Marshal, &o) }

func (o CMW) MarshalCBORTag() ([]byte, error) {
func (o CMW) MarshalCBOR() ([]byte, error) {
s := o.serialization
switch s {
case CBORArray:
return arrayEncode(cbor.Marshal, &o)
case CBORTag:
return o.encodeCBORTag()
}
return nil, fmt.Errorf("invalid serialization format: want CBORArray or CBORTag, got %d", s)
}

func (o CMW) encodeCBORTag() ([]byte, error) {
var (
tag cbor.RawTag
err error
Expand All @@ -99,14 +114,26 @@ func (o CMW) MarshalCBORTag() ([]byte, error) {
}

func (o *CMW) UnmarshalCBOR(b []byte) error {
return arrayDecode[cbor.RawMessage](cbor.Unmarshal, b, o)
if arrayDecode[cbor.RawMessage](cbor.Unmarshal, b, o) == nil {
o.serialization = CBORArray
return nil
}

if o.decodeCBORTag(b) == nil {
// the serialization attribute is set by decodeCBORTag
return nil
}

return errors.New("invalid CBOR-encoded CMW")
}

func (o *CMW) UnmarshalJSON(b []byte) error {
return arrayDecode[json.RawMessage](json.Unmarshal, b, o)
err := arrayDecode[json.RawMessage](json.Unmarshal, b, o)
o.serialization = JSONArray
return err
}

func (o *CMW) UnmarshalCBORTag(b []byte) error {
func (o *CMW) decodeCBORTag(b []byte) error {
var (
v cbor.RawTag
m []byte
Expand All @@ -123,13 +150,14 @@ func (o *CMW) UnmarshalCBORTag(b []byte) error {

_ = o.typ.Set(v.Number)
_ = o.val.Set(m)
o.serialization = CBORTag

return nil
}

func sniff(b []byte) Serialization {
if len(b) == 0 {
return Unknown
return UnknownSerialization
}

if b[0] == 0x82 || b[0] == 0x83 {
Expand All @@ -140,12 +168,12 @@ func sniff(b []byte) Serialization {
return JSONArray
}

return Unknown
return UnknownSerialization
}

type (
arrayDecoder func([]byte, interface{}) error
arrayEncoder func(interface{}) ([]byte, error)
arrayDecoder func([]byte, any) error
arrayEncoder func(any) ([]byte, error)
)

func arrayDecode[V json.RawMessage | cbor.RawMessage](
Expand Down Expand Up @@ -185,7 +213,7 @@ func arrayEncode(enc arrayEncoder, o *CMW) ([]byte, error) {
return nil, fmt.Errorf("type and value MUST be set in CMW")
}

a := []interface{}{o.typ, o.val}
a := []any{o.typ, o.val}

if !o.ind.Empty() {
a = append(a, o.ind)
Expand Down
47 changes: 35 additions & 12 deletions cmw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{uint16(30001)},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
JSONArray,
},
},
{
Expand All @@ -93,6 +94,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{"application/vnd.intel.sgx"},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
JSONArray,
},
},
{
Expand All @@ -102,6 +104,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{"application/vnd.intel.sgx"},
[]byte{0xde, 0xad, 0xbe, 0xef},
testIndicator,
JSONArray,
},
},
{
Expand All @@ -112,6 +115,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{uint16(30001)},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
CBORArray,
},
},
{
Expand All @@ -127,6 +131,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{string("application/vnd.intel.sgx")},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
CBORArray,
},
},
{
Expand All @@ -139,6 +144,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{uint64(1668576818)},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
CBORTag,
},
},
}
Expand Down Expand Up @@ -203,8 +209,9 @@ func Test_Serialize_JSONArray_ok(t *testing.T) {
cmw.SetMediaType(tt.tv.typ)
cmw.SetValue(tt.tv.val)
cmw.SetIndicators(tt.tv.ind...)
cmw.SetSerialization(JSONArray)

actual, err := cmw.Serialize(JSONArray)
actual, err := cmw.Serialize()
assert.NoError(t, err)
assert.JSONEq(t, tt.exp, string(actual))
})
Expand Down Expand Up @@ -259,8 +266,9 @@ func Test_Serialize_CBORArray_ok(t *testing.T) {
cmw.SetContentFormat(tt.tv.typ)
cmw.SetValue(tt.tv.val)
cmw.SetIndicators(tt.tv.ind...)
cmw.SetSerialization(CBORArray)

actual, err := cmw.Serialize(CBORArray)
actual, err := cmw.Serialize()
assert.NoError(t, err)
assert.Equal(t, tt.exp, actual)
})
Expand Down Expand Up @@ -310,8 +318,9 @@ func Test_Serialize_CBORTag_ok(t *testing.T) {

cmw.SetTagNumber(tt.tv.typ)
cmw.SetValue(tt.tv.val)
cmw.SetSerialization(CBORTag)

actual, err := cmw.Serialize(CBORTag)
actual, err := cmw.Serialize()
assert.NoError(t, err)
assert.Equal(t, tt.exp, actual)
})
Expand Down Expand Up @@ -441,7 +450,7 @@ func Test_Deserialize_CBORArray_ko(t *testing.T) {
0x82, 0xfb, 0x40, 0x8f, 0x41, 0xd7, 0x0a, 0x3d, 0x70, 0xa4,
0x44, 0xde, 0xad, 0xbe, 0xef,
},
`unmarshaling type: cannot unmarshal 1000.230000 into uint16`,
`invalid CBOR-encoded CMW`,
},
{
"overflow for type",
Expand All @@ -450,7 +459,7 @@ func Test_Deserialize_CBORArray_ko(t *testing.T) {
0x82, 0x1a, 0x00, 0x01, 0x00, 0x00, 0x44, 0xde, 0xad, 0xbe,
0xef,
},
`unmarshaling type: cannot unmarshal 65536 into uint16`,
`invalid CBOR-encoded CMW`,
},
{
"bad type (float) for value",
Expand All @@ -459,7 +468,7 @@ func Test_Deserialize_CBORArray_ko(t *testing.T) {
0x82, 0x19, 0xff, 0xff, 0xfb, 0x3f, 0xf3, 0x33, 0x33, 0x33,
0x33, 0x33, 0x33,
},
`unmarshaling value: cannot decode value: cbor: cannot unmarshal primitives into Go value of type []uint8`,
`invalid CBOR-encoded CMW`,
},
}

Expand All @@ -481,13 +490,13 @@ func Test_Deserialize_CBORTag(t *testing.T) {
{
"empty CBOR Tag",
[]byte{0xda, 0x63, 0x74, 0x01, 0x01},
`unmarshal CMW CBOR Tag bstr-wrapped value: EOF`,
`invalid CBOR-encoded CMW`,
},
{
"bad type (uint) for value",
// echo "1668546817(1)" | diag2cbor.rb | xxd -i
[]byte{0xda, 0x63, 0x74, 0x01, 0x01, 0x01},
`unmarshal CMW CBOR Tag bstr-wrapped value: cbor: cannot unmarshal positive integer into Go value of type []uint8`,
`invalid CBOR-encoded CMW`,
},
}

Expand All @@ -504,28 +513,42 @@ func Test_EncodeArray_sanitize_input(t *testing.T) {
var cmw CMW

for _, s := range []Serialization{CBORArray, JSONArray} {
_, err := cmw.Serialize(s)
cmw.SetSerialization(s)
_, err := cmw.Serialize()
assert.EqualError(t, err, "type and value MUST be set in CMW")
}

cmw.SetValue([]byte{0xff})

for _, s := range []Serialization{CBORArray, JSONArray} {
_, err := cmw.Serialize(s)
cmw.SetSerialization(s)
_, err := cmw.Serialize()
assert.EqualError(t, err, "type and value MUST be set in CMW")
}

cmw.SetMediaType("")

for _, s := range []Serialization{CBORArray, JSONArray} {
_, err := cmw.Serialize(s)
cmw.SetSerialization(s)
_, err := cmw.Serialize()
assert.EqualError(t, err, "type and value MUST be set in CMW")
}

cmw.SetContentFormat(0)

for _, s := range []Serialization{CBORArray, JSONArray} {
_, err := cmw.Serialize(s)
cmw.SetSerialization(s)
_, err := cmw.Serialize()
assert.NoError(t, err)
}
}

func Test_Serialize_invalid_serialization(t *testing.T) {
var tv CMW

tv.SetMediaType("application/vnd.x")
tv.SetValue([]byte{0x00})

_, err := tv.Serialize()
assert.Error(t, err, "TPDP")
}
Loading

0 comments on commit 1e721a5

Please sign in to comment.