diff --git a/btcutil/go.sum b/btcutil/go.sum index 58e469abe6..5bd6215f1f 100644 --- a/btcutil/go.sum +++ b/btcutil/go.sum @@ -34,6 +34,8 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= diff --git a/go.mod b/go.mod index 2e3333acc1..4a05d30857 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/aead/siphash v1.0.1 // indirect github.com/decred/dcrd/crypto/blake256 v1.0.0 // indirect github.com/golang/snappy v0.0.4 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect diff --git a/go.sum b/go.sum index 1e39ef3263..e77dfa2f5f 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= diff --git a/rpcclient/chain_test.go b/rpcclient/chain_test.go index 16344ff9c0..cc29f99bfc 100644 --- a/rpcclient/chain_test.go +++ b/rpcclient/chain_test.go @@ -2,9 +2,16 @@ package rpcclient import ( "errors" + "github.com/gorilla/websocket" + "net/http" + "net/http/httptest" + "strings" "testing" + "time" ) +var upgrader = websocket.Upgrader{} + // TestUnmarshalGetBlockChainInfoResult ensures that the SoftForks and // UnifiedSoftForks fields of GetBlockChainInfoResult are properly unmarshaled // when using the expected backend version. @@ -129,3 +136,117 @@ func TestFutureGetBlockCountResultReceiveMarshalsResponseCorrectly(t *testing.T) t.Fatalf("unexpected response: %d (0x%X)", res, res) } } + +func TestClientConnectedToWSServerRunner(t *testing.T) { + type TestTableItem struct { + Name string + TestCase func(t *testing.T) + } + + testTable := []TestTableItem{ + TestTableItem{ + Name: "TestGetChainTxStatsAsyncSuccessTx", + TestCase: func(t *testing.T) { + client, serverReceivedChannel, cleanup := makeClient(t) + defer cleanup() + client.GetChainTxStatsAsync() + + message := <-serverReceivedChannel + if message != "{\"jsonrpc\":\"1.0\",\"method\":\"getchaintxstats\",\"params\":[],\"id\":1}" { + t.Fatalf("received unexpected message: %s", message) + } + }, + }, + TestTableItem{ + Name: "TestGetChainTxStatsAsyncShutdownError", + TestCase: func(t *testing.T) { + client, _, cleanup := makeClient(t) + defer cleanup() + + // a bit of a hack here: since there are multiple places where we read + // from the shutdown channel, and it is not buffered, ensure that a shutdown + // message is sent every time it is read from, this will ensure that + // when client.GetChainTxStatsAsync() gets called, it hits the non-blocking + // read from the shutdown channel + go func() { + type shutdownMessage struct{} + for { + client.shutdown <- shutdownMessage{} + } + }() + + var response *Response = nil + + for response == nil { + respChan := client.GetChainTxStatsAsync() + select { + case response = <-respChan: + default: + } + } + + if response.err == nil || response.err.Error() != "the client has been shutdown" { + t.Fatalf("unexpected error: %s", response.err.Error()) + } + }, + }, + } + + // since these tests rely on concurrency, ensure there is a resonable timeout + // that they should run within + for _, testCase := range testTable { + done := make(chan bool) + + go func() { + t.Run(testCase.Name, testCase.TestCase) + done <- true + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatalf("timeout exceeded for: %s", testCase.Name) + } + } +} + +func makeClient(t *testing.T) (*Client, chan string, func()) { + serverReceivedChannel := make(chan string) + s := httptest.NewServer(http.HandlerFunc(makeUpgradeOnConnect(serverReceivedChannel))) + url := strings.TrimPrefix(s.URL, "http://") + + config := ConnConfig{ + DisableTLS: true, + User: "username", + Pass: "password", + Host: url, + } + + client, err := New(&config, nil) + if err != nil { + t.Fatalf("error when creating new client %s", err.Error()) + } + return client, serverReceivedChannel, func() { + s.Close() + } +} + +func makeUpgradeOnConnect(ch chan string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + for { + _, message, err := c.ReadMessage() + if err != nil { + break + } + + go func() { + ch <- string(message) + }() + } + } +}