Skip to content

Commit

Permalink
Check response in Open method:
Browse files Browse the repository at this point in the history
Check status code and response payload
in the Open method to validate the endpoint
is conformant with the response contract.

Update checking of the response error code as
well as if it is nil. This will make sure we
dont error out when a response contains a
value for the error instead of just nil.

Signed-off-by: Jacob Weinstock <[email protected]>
  • Loading branch information
jacobweinstock committed Sep 11, 2023
1 parent ffe1702 commit 55d4044
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 12 deletions.
2 changes: 1 addition & 1 deletion providers/rpc/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ Package rpc is a provider that defines an HTTP request/response contract for han
It allows users a simple way to interoperate with an existing/bespoke out-of-band management solution.
The rpc provider request/response payloads are modeled after JSON-RPC 2.0, but are not JSON-RPC 2.0
compliant so as to allow for more flexibility.
compliant so as to allow for more flexibility and interoperability with existing systems.
*/
package rpc
2 changes: 1 addition & 1 deletion providers/rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (p *Provider) createRequest(ctx context.Context, rp RequestPayload) (*http.
return req, nil
}

func (p *Provider) handleResponse(resp *http.Response, reqKeysAndValues []interface{}) (ResponsePayload, error) {
func (p *Provider) handleResponse(resp *http.Response, reqKeysAndValues []any) (ResponsePayload, error) {
kvs := reqKeysAndValues
defer func() {
if !p.LogNotificationsDisabled {
Expand Down
21 changes: 16 additions & 5 deletions providers/rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rpc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"hash"
Expand Down Expand Up @@ -181,16 +182,26 @@ func (p *Provider) Open(ctx context.Context) error {
return err
}
p.listenerURL = u
testReq, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), nil)
var buf bytes.Buffer
_ = json.NewEncoder(&buf).Encode(RequestPayload{})
testReq, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), bytes.NewReader(buf.Bytes()))
if err != nil {
return err
}
// test that we can communicate with the rpc consumer.
// and that it responses with the spec contract (Response{}).
resp, err := p.Client.Do(testReq)
if err != nil {
return err
}
if resp.StatusCode >= http.StatusInternalServerError {
return fmt.Errorf("issue on the rpc consumer side, status code: %d", resp.StatusCode)
}

// test that the consumer responses with the expected contract (ResponsePayload{}).
var res ResponsePayload
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
return fmt.Errorf("issue with the rpc consumer response: %v", err)
}

return resp.Body.Close()
}
Expand All @@ -216,7 +227,7 @@ func (p *Provider) BootDeviceSet(ctx context.Context, bootDevice string, setPers
if err != nil {
return false, err
}
if resp.Error != nil {
if resp.Error != nil && resp.Error.Code != 0 {
return false, fmt.Errorf("error from rpc consumer: %v", resp.Error)
}

Expand All @@ -239,7 +250,7 @@ func (p *Provider) PowerSet(ctx context.Context, state string) (ok bool, err err
if err != nil {
return ok, err
}
if resp.Error != nil {
if resp.Error != nil && resp.Error.Code != 0 {
return ok, fmt.Errorf("error from rpc consumer: %v", resp.Error)
}

Expand All @@ -260,7 +271,7 @@ func (p *Provider) PowerStateGet(ctx context.Context) (state string, err error)
if err != nil {
return "", err
}
if resp.Error != nil {
if resp.Error != nil && resp.Error.Code != 0 {
return "", fmt.Errorf("error from rpc consumer: %v", resp.Error)
}

Expand Down
6 changes: 1 addition & 5 deletions providers/rpc/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,7 @@ func TestServerErrors(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := New(svr.URL, "127.0.0.1", Secrets{SHA256: {"superSecret1"}})
if err := c.Open(ctx); err != nil {
t.Fatal(err)
}
_, err := c.PowerStateGet(ctx)
if err == nil {
if err := c.Open(ctx); err == nil {
t.Fatal("expected error, got none")
}
})
Expand Down

0 comments on commit 55d4044

Please sign in to comment.