Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the CMW collection #4

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 55 additions & 27 deletions cmw.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@ import (
type Serialization uint

const (
JSONArray = Serialization(iota)
UnknownSerialization = Serialization(iota)
JSONArray
CBORArray
CBORTag
Unknown
)

// a CMW object holds the internal representation of a RATS conceptual message
// wrapper
type CMW struct {
typ Type
val Value
ind Indicator
typ Type
val Value
ind Indicator
serialization Serialization
}

func (o *CMW) SetMediaType(v string) { _ = o.typ.Set(v) }
Expand All @@ -41,41 +42,55 @@ func (o *CMW) SetIndicators(indicators ...Indicator) {

o.ind = v
}
func (o *CMW) SetSerialization(s Serialization) { o.serialization = s }

func (o CMW) GetValue() []byte { return o.val }
func (o CMW) GetType() string { return o.typ.String() }
func (o CMW) GetIndicator() Indicator { return o.ind }
func (o CMW) GetValue() []byte { return o.val }
func (o CMW) GetType() string { return o.typ.String() }
func (o CMW) GetIndicator() Indicator { return o.ind }
func (o CMW) GetSerialization() Serialization { return o.serialization }

// Deserialize a CMW
func (o *CMW) Deserialize(b []byte) error {
switch sniff(b) {
s := sniff(b)

o.serialization = s

switch s {
case JSONArray:
return o.UnmarshalJSON(b)
case CBORArray:
case CBORArray, CBORTag:
return o.UnmarshalCBOR(b)
case CBORTag:
return o.UnmarshalCBORTag(b)
}

return errors.New("unknown CMW format")
}

// Serialize a CMW according to the provided Serialization
func (o CMW) Serialize(s Serialization) ([]byte, error) {
// Serialize a CMW according to its provided Serialization
func (o CMW) Serialize() ([]byte, error) {
s := o.serialization
switch s {
case JSONArray:
return o.MarshalJSON()
case CBORArray:
case CBORArray, CBORTag:
return o.MarshalCBOR()
case CBORTag:
return o.MarshalCBORTag()
}
return nil, fmt.Errorf("invalid serialization format %d", s)
}

func (o CMW) MarshalJSON() ([]byte, error) { return arrayEncode(json.Marshal, &o) }
func (o CMW) MarshalCBOR() ([]byte, error) { return arrayEncode(cbor.Marshal, &o) }

func (o CMW) MarshalCBORTag() ([]byte, error) {
func (o CMW) MarshalCBOR() ([]byte, error) {
s := o.serialization
switch s {
case CBORArray:
return arrayEncode(cbor.Marshal, &o)
case CBORTag:
return o.encodeCBORTag()
}
return nil, fmt.Errorf("invalid serialization format: want CBORArray or CBORTag, got %d", s)
}

func (o CMW) encodeCBORTag() ([]byte, error) {
var (
tag cbor.RawTag
err error
Expand All @@ -99,14 +114,26 @@ func (o CMW) MarshalCBORTag() ([]byte, error) {
}

func (o *CMW) UnmarshalCBOR(b []byte) error {
return arrayDecode[cbor.RawMessage](cbor.Unmarshal, b, o)
if arrayDecode[cbor.RawMessage](cbor.Unmarshal, b, o) == nil {
o.serialization = CBORArray
return nil
}

if o.decodeCBORTag(b) == nil {
// the serialization attribute is set by decodeCBORTag
return nil
}

return errors.New("invalid CBOR-encoded CMW")
}

func (o *CMW) UnmarshalJSON(b []byte) error {
return arrayDecode[json.RawMessage](json.Unmarshal, b, o)
err := arrayDecode[json.RawMessage](json.Unmarshal, b, o)
o.serialization = JSONArray
return err
}

func (o *CMW) UnmarshalCBORTag(b []byte) error {
func (o *CMW) decodeCBORTag(b []byte) error {
var (
v cbor.RawTag
m []byte
Expand All @@ -123,13 +150,14 @@ func (o *CMW) UnmarshalCBORTag(b []byte) error {

_ = o.typ.Set(v.Number)
_ = o.val.Set(m)
o.serialization = CBORTag

return nil
}

func sniff(b []byte) Serialization {
if len(b) == 0 {
return Unknown
return UnknownSerialization
}

if b[0] == 0x82 || b[0] == 0x83 {
Expand All @@ -140,12 +168,12 @@ func sniff(b []byte) Serialization {
return JSONArray
}

return Unknown
return UnknownSerialization
}

type (
arrayDecoder func([]byte, interface{}) error
arrayEncoder func(interface{}) ([]byte, error)
arrayDecoder func([]byte, any) error
arrayEncoder func(any) ([]byte, error)
)

func arrayDecode[V json.RawMessage | cbor.RawMessage](
Expand Down Expand Up @@ -185,7 +213,7 @@ func arrayEncode(enc arrayEncoder, o *CMW) ([]byte, error) {
return nil, fmt.Errorf("type and value MUST be set in CMW")
}

a := []interface{}{o.typ, o.val}
a := []any{o.typ, o.val}

if !o.ind.Empty() {
a = append(a, o.ind)
Expand Down
47 changes: 35 additions & 12 deletions cmw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{uint16(30001)},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
JSONArray,
},
},
{
Expand All @@ -93,6 +94,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{"application/vnd.intel.sgx"},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
JSONArray,
},
},
{
Expand All @@ -102,6 +104,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{"application/vnd.intel.sgx"},
[]byte{0xde, 0xad, 0xbe, 0xef},
testIndicator,
JSONArray,
},
},
{
Expand All @@ -112,6 +115,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{uint16(30001)},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
CBORArray,
},
},
{
Expand All @@ -127,6 +131,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{string("application/vnd.intel.sgx")},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
CBORArray,
},
},
{
Expand All @@ -139,6 +144,7 @@ func Test_Deserialize_ok(t *testing.T) {
Type{uint64(1668576818)},
[]byte{0xde, 0xad, 0xbe, 0xef},
IndicatorNone,
CBORTag,
},
},
}
Expand Down Expand Up @@ -203,8 +209,9 @@ func Test_Serialize_JSONArray_ok(t *testing.T) {
cmw.SetMediaType(tt.tv.typ)
cmw.SetValue(tt.tv.val)
cmw.SetIndicators(tt.tv.ind...)
cmw.SetSerialization(JSONArray)

actual, err := cmw.Serialize(JSONArray)
actual, err := cmw.Serialize()
assert.NoError(t, err)
assert.JSONEq(t, tt.exp, string(actual))
})
Expand Down Expand Up @@ -259,8 +266,9 @@ func Test_Serialize_CBORArray_ok(t *testing.T) {
cmw.SetContentFormat(tt.tv.typ)
cmw.SetValue(tt.tv.val)
cmw.SetIndicators(tt.tv.ind...)
cmw.SetSerialization(CBORArray)

actual, err := cmw.Serialize(CBORArray)
actual, err := cmw.Serialize()
assert.NoError(t, err)
assert.Equal(t, tt.exp, actual)
})
Expand Down Expand Up @@ -310,8 +318,9 @@ func Test_Serialize_CBORTag_ok(t *testing.T) {

cmw.SetTagNumber(tt.tv.typ)
cmw.SetValue(tt.tv.val)
cmw.SetSerialization(CBORTag)

actual, err := cmw.Serialize(CBORTag)
actual, err := cmw.Serialize()
assert.NoError(t, err)
assert.Equal(t, tt.exp, actual)
})
Expand Down Expand Up @@ -441,7 +450,7 @@ func Test_Deserialize_CBORArray_ko(t *testing.T) {
0x82, 0xfb, 0x40, 0x8f, 0x41, 0xd7, 0x0a, 0x3d, 0x70, 0xa4,
0x44, 0xde, 0xad, 0xbe, 0xef,
},
`unmarshaling type: cannot unmarshal 1000.230000 into uint16`,
`invalid CBOR-encoded CMW`,
},
{
"overflow for type",
Expand All @@ -450,7 +459,7 @@ func Test_Deserialize_CBORArray_ko(t *testing.T) {
0x82, 0x1a, 0x00, 0x01, 0x00, 0x00, 0x44, 0xde, 0xad, 0xbe,
0xef,
},
`unmarshaling type: cannot unmarshal 65536 into uint16`,
`invalid CBOR-encoded CMW`,
},
{
"bad type (float) for value",
Expand All @@ -459,7 +468,7 @@ func Test_Deserialize_CBORArray_ko(t *testing.T) {
0x82, 0x19, 0xff, 0xff, 0xfb, 0x3f, 0xf3, 0x33, 0x33, 0x33,
0x33, 0x33, 0x33,
},
`unmarshaling value: cannot decode value: cbor: cannot unmarshal primitives into Go value of type []uint8`,
`invalid CBOR-encoded CMW`,
},
}

Expand All @@ -481,13 +490,13 @@ func Test_Deserialize_CBORTag(t *testing.T) {
{
"empty CBOR Tag",
[]byte{0xda, 0x63, 0x74, 0x01, 0x01},
`unmarshal CMW CBOR Tag bstr-wrapped value: EOF`,
`invalid CBOR-encoded CMW`,
},
{
"bad type (uint) for value",
// echo "1668546817(1)" | diag2cbor.rb | xxd -i
[]byte{0xda, 0x63, 0x74, 0x01, 0x01, 0x01},
`unmarshal CMW CBOR Tag bstr-wrapped value: cbor: cannot unmarshal positive integer into Go value of type []uint8`,
`invalid CBOR-encoded CMW`,
},
}

Expand All @@ -504,28 +513,42 @@ func Test_EncodeArray_sanitize_input(t *testing.T) {
var cmw CMW

for _, s := range []Serialization{CBORArray, JSONArray} {
_, err := cmw.Serialize(s)
cmw.SetSerialization(s)
_, err := cmw.Serialize()
assert.EqualError(t, err, "type and value MUST be set in CMW")
}

cmw.SetValue([]byte{0xff})

for _, s := range []Serialization{CBORArray, JSONArray} {
_, err := cmw.Serialize(s)
cmw.SetSerialization(s)
_, err := cmw.Serialize()
assert.EqualError(t, err, "type and value MUST be set in CMW")
}

cmw.SetMediaType("")

for _, s := range []Serialization{CBORArray, JSONArray} {
_, err := cmw.Serialize(s)
cmw.SetSerialization(s)
_, err := cmw.Serialize()
assert.EqualError(t, err, "type and value MUST be set in CMW")
}

cmw.SetContentFormat(0)

for _, s := range []Serialization{CBORArray, JSONArray} {
_, err := cmw.Serialize(s)
cmw.SetSerialization(s)
_, err := cmw.Serialize()
assert.NoError(t, err)
}
}

func Test_Serialize_invalid_serialization(t *testing.T) {
var tv CMW

tv.SetMediaType("application/vnd.x")
tv.SetValue([]byte{0x00})

_, err := tv.Serialize()
assert.Error(t, err, "TPDP")
}
Loading
Loading