diff --git a/beacon/versioned_beacon_state.go b/beacon/versioned_beacon_state.go index 32841707..bbbb2eeb 100644 --- a/beacon/versioned_beacon_state.go +++ b/beacon/versioned_beacon_state.go @@ -82,14 +82,22 @@ func CreateVersionedState(state interface{}) (spec.VersionedBeaconState, error) func UnmarshalSSZVersionedBeaconState(data []byte) (*spec.VersionedBeaconState, error) { beaconState := &spec.VersionedBeaconState{} + denebBeaconState := &deneb.BeaconState{} // Try to unmarshal using Deneb - err := beaconState.Deneb.UnmarshalSSZ(data) + err := denebBeaconState.UnmarshalSSZ(data) if err != nil { // If Deneb fails, try Capella - err = beaconState.Capella.UnmarshalSSZ(data) + capellaBeaconState := &capella.BeaconState{} + err = capellaBeaconState.UnmarshalSSZ(data) if err != nil { return nil, err + } else { + beaconState.Capella = capellaBeaconState + beaconState.Version = spec.DataVersionCapella } + } else { + beaconState.Deneb = denebBeaconState + beaconState.Version = spec.DataVersionDeneb } return beaconState, nil @@ -97,10 +105,14 @@ func UnmarshalSSZVersionedBeaconState(data []byte) (*spec.VersionedBeaconState, func MarshalSSZVersionedBeaconState(beaconState spec.VersionedBeaconState) ([]byte, error) { var data []byte + var err error // Try to marshal using Deneb - data, err := beaconState.Deneb.MarshalSSZ() - if err != nil { - // If Deneb fails, try Capella + if beaconState.Version == spec.DataVersionDeneb { + data, err = beaconState.Deneb.MarshalSSZ() + if err != nil { + return nil, err + } + } else { data, err = beaconState.Capella.MarshalSSZ() if err != nil { return nil, err diff --git a/merkle_util_test.go b/merkle_util_test.go index 5e370e1c..ba364842 100644 --- a/merkle_util_test.go +++ b/merkle_util_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/attestantio/go-eth2-client/spec" + "github.com/attestantio/go-eth2-client/spec/altair" "github.com/attestantio/go-eth2-client/spec/capella" "github.com/attestantio/go-eth2-client/spec/deneb" "github.com/attestantio/go-eth2-client/spec/phase0" @@ -250,6 +251,96 @@ func TestProveWithdrawals(t *testing.T) { assert.True(t, flag, "Historical Summary Block Root Proof %v failed") } +func TestUnmarshalSSZVersionedBeaconStateDeneb(t *testing.T) { + oracleStateBytes, err := oracleState.MarshalSSZ() + if err != nil { + fmt.Println("error", err) + } + + versionedBeaconState, err := beacon.UnmarshalSSZVersionedBeaconState(oracleStateBytes) + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedBeaconState.Version, spec.DataVersionDeneb, "Version %v failed") + + versionedBeaconStateBytes, err := versionedBeaconState.Deneb.MarshalSSZ() + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedBeaconStateBytes, oracleStateBytes, "Version %v failed") + assert.Nil(t, err, "Error %v failed") +} + +func TestUnmarshalSSZVersionedBeaconStateCapella(t *testing.T) { + var capellaState capella.BeaconState + capellaStateJSON, err := ParseJSONFileCapella("data/goerli_slot_6409723.json") + if err != nil { + fmt.Println("error", err) + } + ParseCapellaBeaconStateFromJSON(*capellaStateJSON, &capellaState) + + capellaStateBytes, err := capellaState.MarshalSSZ() + if err != nil { + fmt.Println("error", err) + } + + versionedBeaconState, err := beacon.UnmarshalSSZVersionedBeaconState(capellaStateBytes) + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedBeaconState.Version, spec.DataVersionCapella, "Version %v failed") + + versionedBeaconStateBytes, err := versionedBeaconState.Capella.MarshalSSZ() + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedBeaconStateBytes, capellaStateBytes, "Version %v failed") + assert.Nil(t, err, "Error %v failed") +} + +func TestMarshalSSZVersionedBeaconStateDeneb(t *testing.T) { + oracleStateBytes, err := oracleState.MarshalSSZ() + if err != nil { + fmt.Println("error", err) + } + versionedBeaconState, err := beacon.CreateVersionedState(&oracleState) + if err != nil { + fmt.Println("error", err) + } + + versionedBeaconStateBytes, err := beacon.MarshalSSZVersionedBeaconState(versionedBeaconState) + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedBeaconStateBytes, oracleStateBytes, "Version %v failed") + assert.Nil(t, err, "Error %v failed") +} + +func TestMarshalSSZVersionedBeaconStateCapella(t *testing.T) { + var capellaState capella.BeaconState + capellaStateJSON, err := ParseJSONFileCapella("data/goerli_slot_6409723.json") + if err != nil { + fmt.Println("error", err) + } + ParseCapellaBeaconStateFromJSON(*capellaStateJSON, &capellaState) + + capellaStateBytes, err := capellaState.MarshalSSZ() + if err != nil { + fmt.Println("error", err) + } + versionedBeaconState, err := beacon.CreateVersionedState(&capellaState) + if err != nil { + fmt.Println("error", err) + } + + versionedBeaconStateBytes, err := beacon.MarshalSSZVersionedBeaconState(versionedBeaconState) + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedBeaconStateBytes, capellaStateBytes, "Version %v failed") + assert.Nil(t, err, "Error %v failed") +} + func TestGenerateWithdrawalCredentialsProof(t *testing.T) { // picking up one random validator index @@ -270,6 +361,64 @@ func TestGenerateWithdrawalCredentialsProof(t *testing.T) { assert.True(t, flag, "Proof %v failed") } +func TestCreateVersionedSignedBlockDeneb(t *testing.T) { + block := deneb.BeaconBlock{} + versionedBlock, err := beacon.CreateVersionedSignedBlock(block) + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedBlock.Version, spec.DataVersionDeneb, "Version %v failed") + assert.Nil(t, err, "Error %v failed") +} + +func TestCreateVersionedSignedBlockCapella(t *testing.T) { + block := capella.BeaconBlock{} + versionedBlock, err := beacon.CreateVersionedSignedBlock(block) + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedBlock.Version, spec.DataVersionCapella, "Version %v failed") + assert.Nil(t, err, "Error %v failed") +} + +func TestCreateVersionedSignedBlockAltair(t *testing.T) { + block := altair.BeaconBlock{} + _, err := beacon.CreateVersionedSignedBlock(block) + if err != nil { + fmt.Println("error", err) + } + assert.NotNil(t, err, "error %v was nil") +} + +func TestCreateVersionedBeaconStateDeneb(t *testing.T) { + oracleState := deneb.BeaconState{} + versionedState, err := beacon.CreateVersionedState(&oracleState) + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedState.Version, spec.DataVersionDeneb, "Version %v failed") + assert.Nil(t, err, "Error %v failed") +} + +func TestCreateVersionedBeaconStateCapella(t *testing.T) { + state := capella.BeaconState{} + versionedState, err := beacon.CreateVersionedState(&state) + if err != nil { + fmt.Println("error", err) + } + assert.Equal(t, versionedState.Version, spec.DataVersionCapella, "Version %v failed") + assert.Nil(t, err, "Error %v failed") +} + +func TestCreateVersionedBeaconStateAltair(t *testing.T) { + state := altair.BeaconState{} + _, err := beacon.CreateVersionedState(&state) + if err != nil { + fmt.Println("error", err) + } + assert.NotNil(t, err, "error %v was nil") +} + func TestProveValidatorBalanceAgainstValidatorBalanceList(t *testing.T) { validatorIndex := phase0.ValidatorIndex(REPOINTED_VALIDATOR_INDEX)