diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml new file mode 100644 index 0000000..109eff0 --- /dev/null +++ b/.github/workflows/build-test.yml @@ -0,0 +1,14 @@ +on: [push, pull_request] +name: Test +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Install Go + uses: actions/setup-go@v2 + with: + go-version: 1.17 + - name: Checkout code + uses: actions/checkout@v2 + - name: Test + run: go test ./... diff --git a/cmd/cli/bulk.go b/cmd/cli/bulk.go new file mode 100644 index 0000000..59b75e8 --- /dev/null +++ b/cmd/cli/bulk.go @@ -0,0 +1,244 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "strings" + "sync" + + "github.com/multiplay/go-svrquery/lib/svrquery" + "github.com/multiplay/go-svrquery/lib/svrquery/protocol" +) + +const ( + numWorkers = 100 +) + +var ( + errNoItem = errors.New("no item") + errEntryInvalid = errors.New("invalid entry") +) + +// BulkResponseItem contains the information about the query being performed +// against a single server. +type BulkResponseItem struct { + Address string `json:"address,omitempty"` + ServerInfo *BulkResponseServerInfoItem `json:"serverInfo,omitempty"` + Error string `json:"error,omitempty"` +} + +// encode writes the JSON encoded version of i to w using the encoder e which writes to b. +// It strips the trailing \n from the output before writing to w. +func (i *BulkResponseItem) encode(w io.Writer, b *bytes.Buffer, e *json.Encoder) error { + defer b.Reset() + + if err := e.Encode(i); err != nil { + return fmt.Errorf("encode item %v: %w", i, err) + } + + if _, err := w.Write(bytes.TrimRight(b.Bytes(), "\n")); err != nil { + return fmt.Errorf("write item: %w", err) + } + + return nil +} + +// BulkResponseServerInfoItem containing basic server information. +type BulkResponseServerInfoItem struct { + CurrentPlayers int64 `json:"currentPlayers"` + MaxPlayers int64 `json:"maxPlayers"` + Map string `json:"map"` +} + +// queryBulk queries a bulk set of servers from a query file writing the JSON results to output. +func queryBulk(file string, output io.Writer) (err error) { + work := make(chan string, numWorkers) // Buffered to ensure we can busy all workers. + results := make(chan BulkResponseItem, numWorkers) // Buffered to improve worker concurrency. + + // Create a pool of workers to process work. + var wgWorkers sync.WaitGroup + wgWorkers.Add(numWorkers) + for w := 1; w <= numWorkers; w++ { + c, err := svrquery.NewBulkClient() + if err != nil { + close(work) // Ensure that existing workers return. + return fmt.Errorf("bulk client: %w", err) + } + + go func() { + defer wgWorkers.Done() + worker(work, results, c) + }() + } + + // Create a writer to write the results to output as they become available. + errc := make(chan error) + go func() { + errc <- writer(output, results) + }() + + // Queue work onto the channel. + if err = producer(file, work); err != nil { + err = fmt.Errorf("producer: %w", err) + } + + // Wait for all workers to complete so that we can safely close results + // that will trigger writer to return once its processed all results. + wgWorkers.Wait() + close(results) + + if werr := <-errc; werr != nil { + if err != nil { + return fmt.Errorf("%w, writer: %s", err, werr) + } + return fmt.Errorf("writer: %w", werr) + } + + return err +} + +// writer writes results as JSON encoded array to w. +func writer(w io.Writer, results <-chan BulkResponseItem) (err error) { + if _, err = w.Write([]byte{'['}); err != nil { + return fmt.Errorf("write header: %w", err) + } + + // Do our best to write valid JSON by ensuring we always write + // the closing ]. If a previous encode fails, this could still + // be insufficient. + defer func() { + if _, werr := w.Write([]byte("]\n")); werr != nil { + werr = fmt.Errorf("write trailer: %w", err) + if err == nil { + err = werr + } + } + }() + + var b bytes.Buffer + e := json.NewEncoder(&b) + + // Process the first item before looping so separating + // comma can be written easily. + i, ok := <-results + if !ok { + return nil + } + + if err := i.encode(w, &b, e); err != nil { + return err + } + + for i := range results { + if _, err := w.Write([]byte(",")); err != nil { + return fmt.Errorf("write set: %w", err) + } + + if err := i.encode(w, &b, e); err != nil { + return err + } + } + + return nil +} + +// producer reads lines from file sending them to work. +// work will be closed before return. +func producer(file string, work chan<- string) error { + defer close(work) + + f, err := os.Open(file) + if err != nil { + return err + } + defer f.Close() + + s := bufio.NewScanner(f) + for s.Scan() { + work <- s.Text() + } + + return s.Err() +} + +// worker calls processBulkEntry for each item read from work, writing the result to results. +func worker(work <-chan string, results chan<- BulkResponseItem, client *svrquery.BulkClient) { + for e := range work { + results <- processBulkEntry(e, client) + } +} + +// processBulkEntry decodes and processes an entry returning an item containing the result or error. +func processBulkEntry(entry string, client *svrquery.BulkClient) (item BulkResponseItem) { + querySection, addressSection, err := parseEntry(entry) + if err != nil { + item.Error = fmt.Sprintf("parse file entry: %s", err) + return item + } + + item.Address = addressSection + + // If the query contains any options retrieve and use them. + querySection, options, err := parseOptions(querySection) + if err != nil { + item.Error = err.Error() + return item + } + + resp, err := client.Query(querySection, addressSection, options...) + if err != nil { + item.Error = fmt.Sprintf("query client: %s", err) + return item + } + + item.ServerInfo = &BulkResponseServerInfoItem{ + CurrentPlayers: resp.NumClients(), + MaxPlayers: resp.MaxClients(), + Map: "UNKNOWN", + } + + if currentMap, ok := resp.(protocol.Mapper); ok { + item.ServerInfo.Map = currentMap.Map() + } + return item +} + +// pareEntry parses the details from entry returning the query and address sections. +func parseEntry(entry string) (querySection, addressSection string, err error) { + entry = strings.TrimSpace(entry) + if entry == "" { + return "", "", fmt.Errorf("parse entry %q: %w", entry, errNoItem) + } + + sections := strings.Split(entry, " ") + if len(sections) != 2 { + return "", "", fmt.Errorf("%w %q: wrong number of sections %d", errEntryInvalid, entry, len(sections)) + } + + return sections[0], sections[1], nil +} + +// parseOptions parses querySection returning the baseQuery and query options. +func parseOptions(querySection string) (baseQuery string, options []svrquery.Option, err error) { + options = make([]svrquery.Option, 0) + protocolSections := strings.Split(querySection, ",") + for i := 1; i < len(protocolSections); i++ { + keyVal := strings.SplitN(protocolSections[i], "=", 2) + if len(keyVal) != 2 { + return "", nil, fmt.Errorf("key value pair invalid: %v", keyVal) + + } + + // Only support key at the moment. + switch strings.ToLower(keyVal[0]) { + case "key": + options = append(options, svrquery.WithKey(keyVal[1])) + } + } + return protocolSections[0], options, nil +} diff --git a/cmd/cli/bulk_test.go b/cmd/cli/bulk_test.go new file mode 100644 index 0000000..e0807ee --- /dev/null +++ b/cmd/cli/bulk_test.go @@ -0,0 +1,101 @@ +package main + +import ( + "testing" + + "github.com/multiplay/go-svrquery/lib/svrquery" + "github.com/stretchr/testify/require" +) + +func TestParseEntry(t *testing.T) { + testCases := []struct { + name string + input string + expQuery string + expAddress string + expErr error + }{ + { + name: "ok", + input: "sqp 1.2.3.4:1234", + expQuery: "sqp", + expAddress: "1.2.3.4:1234", + }, + { + name: "empty line", + input: "", + expErr: errNoItem, + }, + { + name: "invalid entry", + input: "sqp 1.2.3.4:1234 extra", + expErr: errEntryInvalid, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + query, addr, err := parseEntry(tc.input) + if err != nil { + require.ErrorIs(t, err, tc.expErr) + return + } + require.NoError(t, err) + require.Equal(t, tc.expQuery, query) + require.Equal(t, tc.expAddress, addr) + }) + } +} + +func TestCreateClient(t *testing.T) { + testCases := []struct { + name string + query string + expQuery string + expKey string + expErr error + }{ + { + name: "ok", + query: "tf2e", + expQuery: "tf2e", + }, + { + name: "with_key", + query: "tf2e,key=val", + expKey: "val", + expQuery: "tf2e", + }, + { + name: "with_unsupported_other", + query: "tf2e,other=val", + expQuery: "tf2e", + }, + { + name: "invalid entry", + query: "tf2e", + expQuery: "tf2e", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + baseQuery, options, err := parseOptions(tc.query) + if err != nil { + require.ErrorIs(t, err, tc.expErr) + return + } + require.NoError(t, err) + require.Equal(t, tc.expQuery, baseQuery) + + // Validate key setting + if tc.expKey != "" { + require.Len(t, options, 1) + c := svrquery.Client{} + require.NoError(t, options[0](&c)) + require.Equal(t, tc.expKey, c.Key()) + } + require.NotNil(t, options) + }) + } +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 4800db6..d8e0c46 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -17,11 +17,21 @@ import ( func main() { clientAddr := flag.String("addr", "", "Address to connect to e.g. 127.0.0.1:12345") proto := flag.String("proto", "", "Protocol e.g. sqp, tf2e, tf2e-v7, tf2e-v8") + key := flag.String("key", "", "Key to use to authenticate") + file := flag.String("file", "", "Bulk file to execute to get basic server information") serverAddr := flag.String("server", "", "Address to start server e.g. 127.0.0.1:12121, :23232") flag.Parse() l := log.New(os.Stderr, "", 0) + if *file != "" { + // Use bulk file mode + if err := queryBulk(*file, os.Stdout); err != nil { + l.Fatal(err) + } + return + } + if *serverAddr != "" && *clientAddr != "" { bail(l, "Cannot run both a server and a client. Specify either -addr OR -server flags") } @@ -36,20 +46,25 @@ func main() { if *proto == "" { bail(l, "Protocol required in server mode") } - queryMode(l, *proto, *clientAddr) + queryMode(l, *proto, *clientAddr, *key) default: bail(l, "Please supply some options") } } -func queryMode(l *log.Logger, proto, address string) { - if err := query(proto, address); err != nil { +func queryMode(l *log.Logger, proto, address, key string) { + if err := query(proto, address, key); err != nil { l.Fatal(err) } } -func query(proto, address string) error { - c, err := svrquery.NewClient(proto, address) +func query(proto, address, key string) error { + options := make([]svrquery.Option, 0) + if key != "" { + options = append(options, svrquery.WithKey(key)) + } + + c, err := svrquery.NewClient(proto, address, options...) if err != nil { return err } @@ -84,6 +99,9 @@ func server(l *log.Logger, proto, address string) error { Map: "Map", Port: 1000, }) + if err != nil { + return err + } addr, err := net.ResolveUDPAddr("udp4", address) if err != nil { diff --git a/go.mod b/go.mod index fb175e8..28fce3e 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,8 @@ go 1.13 require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/netdata/go-orchestrator v0.0.0-20190905093727-c793edba0e8f - github.com/stretchr/objx v0.1.1 // indirect - github.com/stretchr/testify v1.4.0 + github.com/stretchr/objx v0.3.0 // indirect + github.com/stretchr/testify v1.7.0 golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 0b54d61..7a92cda 100644 --- a/go.sum +++ b/go.sum @@ -12,9 +12,14 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.3.0 h1:NGXK3lHquSN08v5vWalVI/L8XU9hdzE/G6xsrze47As= +github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 h1:DH4skfRX4EBpamg7iV4ZlCpblAHI6s6TDM39bFZumv8= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -22,3 +27,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lib/svrquery/bulk_client.go b/lib/svrquery/bulk_client.go new file mode 100644 index 0000000..7fc2782 --- /dev/null +++ b/lib/svrquery/bulk_client.go @@ -0,0 +1,54 @@ +package svrquery + +import ( + "net" + + "github.com/multiplay/go-svrquery/lib/svrquery/protocol" +) + +// BulkClient is a client which can be reused with multiple requests. +type BulkClient struct { + client *Client +} + +// NewBulkClient creates a new client with no protocol or +func NewBulkClient(options ...Option) (*BulkClient, error) { + c := &Client{ + network: DefaultNetwork, + timeout: DefaultTimeout, + } + + for _, o := range options { + if err := o(c); err != nil { + return nil, err + } + } + + return &BulkClient{client: c}, nil +} + +// Query runs a query against addr with proto and options. +func (b *BulkClient) Query(proto, addr string, options ...Option) (protocol.Responser, error) { + f, err := protocol.Get(proto) + if err != nil { + return nil, err + } + + for _, o := range options { + if err := o(b.client); err != nil { + return nil, err + } + } + + b.client.Queryer = f(b.client) + + if b.client.ua, err = net.ResolveUDPAddr(b.client.network, addr); err != nil { + return nil, err + } + + if b.client.c, err = net.DialUDP(b.client.network, nil, b.client.ua); err != nil { + return nil, err + } + + return b.client.Query() +} diff --git a/lib/svrquery/protocol/interfaces.go b/lib/svrquery/protocol/interfaces.go index 150d6d2..a4895e8 100644 --- a/lib/svrquery/protocol/interfaces.go +++ b/lib/svrquery/protocol/interfaces.go @@ -17,6 +17,11 @@ type Responser interface { MaxClients() int64 } +// Mapper represents something which can return the current map. +type Mapper interface { + Map() string +} + // Client is an interface which is implemented by types which can act a query transport. type Client interface { io.ReadWriteCloser diff --git a/lib/svrquery/protocol/sqp/types.go b/lib/svrquery/protocol/sqp/types.go index d200b3f..1454bb3 100644 --- a/lib/svrquery/protocol/sqp/types.go +++ b/lib/svrquery/protocol/sqp/types.go @@ -77,6 +77,11 @@ func (q *QueryResponse) NumClients() int64 { return int64(q.ServerInfo.CurrentPlayers) } +// Map implements protocol.Mapper. +func (q *QueryResponse) Map() string { + return q.ServerInfo.Map +} + type infoHeader struct { Name string Type DataType diff --git a/lib/svrquery/protocol/titanfall/query.go b/lib/svrquery/protocol/titanfall/query.go index 60e877b..bd01281 100644 --- a/lib/svrquery/protocol/titanfall/query.go +++ b/lib/svrquery/protocol/titanfall/query.go @@ -1,8 +1,13 @@ package titanfall import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" "encoding/binary" "fmt" + "io" "github.com/multiplay/go-svrquery/lib/svrquery/common" "github.com/multiplay/go-svrquery/lib/svrquery/protocol" @@ -12,7 +17,10 @@ import ( var ( // minLength is the smallest packet we can expect. - minLength = 26 + minLength = 26 + tagSize = 16 + packetSize = 1200 + gcmAdditionalData = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} ) type queryer struct { @@ -29,11 +37,76 @@ func newQueryer(version byte) func(c protocol.Client) protocol.Queryer { } } +// encrypt encrypts the byte buffer given to it. +func (q *queryer) encrypt(b []byte) ([]byte, error) { + keyBytes, err := base64.StdEncoding.DecodeString(q.c.Key()) + if err != nil { + return nil, fmt.Errorf("decode key: %w", err) + } + + c, err := aes.NewCipher(keyBytes) + if err != nil { + return nil, fmt.Errorf("new aes cipher: %w", err) + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, fmt.Errorf("new gcm: %w", err) + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, fmt.Errorf("read random nonce: %w", err) + } + + // This will output in the form CipherTest | Tag and will need rearranging. + cipherTextAndTag := gcm.Seal(nil, nonce, b, gcmAdditionalData) + + // Rearrange output to nonce | tag | ciphertext. + newCipherText := nonce + newCipherText = append(newCipherText, cipherTextAndTag[len(cipherTextAndTag)-tagSize:]...) + newCipherText = append(newCipherText, cipherTextAndTag[:len(cipherTextAndTag)-tagSize]...) + + return newCipherText, nil +} + +// decrypt decrypts the byte buffer given to it. +func (q *queryer) decrypt(b []byte) ([]byte, error) { + keyBytes, err := base64.StdEncoding.DecodeString(q.c.Key()) + if err != nil { + return nil, fmt.Errorf("decode key: %w", err) + } + + c, err := aes.NewCipher(keyBytes) + if err != nil { + return nil, fmt.Errorf("new aes cipher: %w", err) + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, fmt.Errorf("new gcm: %w", err) + } + + if len(b) < gcm.NonceSize() { + return nil, fmt.Errorf("incoming bytes smaller than %d", gcm.NonceSize()) + } + + nonce, tag, b := b[:gcm.NonceSize()], b[gcm.NonceSize():gcm.NonceSize()+tagSize], b[gcm.NonceSize()+tagSize:] + b = append(b, tag...) + plaintext, err := gcm.Open(nil, nonce, b, gcmAdditionalData) + if err != nil { + return nil, err + } + + return plaintext, nil +} + // Query implements protocol.Queryer. -func (q *queryer) Query() (protocol.Responser, error) { - b := make([]byte, 1200) +func (q *queryer) Query() (resp protocol.Responser, err error) { + b := make([]byte, packetSize) copy(b, q.serverInfoPkt()) + // For older query versions we use a keyed magic section to auth. For newer versions we use encryption if key := q.c.Key(); key != "" { if q.version < 5 { // If keyed data asked for bump version sent to supported version level. @@ -42,18 +115,32 @@ func (q *queryer) Query() (protocol.Responser, error) { copy(b[6:], key) } + if q.version >= 8 && q.c.Key() != "" { + b, err = q.encrypt(b) + if err != nil { + return nil, err + } + } + if _, err := q.c.Write(b); err != nil { - return nil, err + return nil, fmt.Errorf("query write: %w", err) } n, err := q.c.Read(b) if err != nil { - return nil, err + return nil, fmt.Errorf("query read: %w", err) } else if n < minLength { return nil, fmt.Errorf("packet too short (len: %d)", n) } - r := common.NewBinaryReader(b[:n], binary.LittleEndian) + if q.version >= 8 && q.c.Key() != "" { + b, err = q.decrypt(b[:n]) + if err != nil { + return nil, err + } + } + + r := common.NewBinaryReader(b, binary.LittleEndian) i := &Info{} // Header. diff --git a/lib/svrquery/protocol/titanfall/query_test.go b/lib/svrquery/protocol/titanfall/query_test.go index 1edb3ae..71b7efb 100644 --- a/lib/svrquery/protocol/titanfall/query_test.go +++ b/lib/svrquery/protocol/titanfall/query_test.go @@ -77,13 +77,26 @@ func TestQuery(t *testing.T) { } v7.TeamsLeftWithPlayersNum = 6 + v8 := v7 + v8.Version = 8 + v8.InstanceInfoV8 = InstanceInfoV8{ + Retail: 1, + InstanceType: 2, + ClientCRC: 4294967295, + NetProtocol: 526, + HealthFlags: 0, + RandomServerID: 0, + } + v8.InstanceInfo = InstanceInfo{} + cases := []struct { - name string - version byte - request string - response string - key string - expected Info + name string + version byte + request string + response string + key string + expected Info + expEncypted bool }{ { name: "v3", @@ -99,6 +112,15 @@ func TestQuery(t *testing.T) { response: "response-v7", expected: v7, }, + { + name: "v8", + version: 8, + request: "request-v8", + response: "response-v8", + expected: v8, + key: testKey, + expEncypted: true, + }, { name: "keyed", version: 5, @@ -119,20 +141,56 @@ func TestQuery(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { + var err error + mc := &clienttest.MockClient{} + mc.On("Key").Return("Z2ZkZ3Nnbmpza2U0cnRyZQ==") + p := queryer{ + c: mc, + version: tc.version, + } + req := clienttest.LoadData(t, testDir, tc.request) resp := clienttest.LoadData(t, testDir, tc.response) - m := &clienttest.MockClient{} - m.On("Write", req).Return(len(req), nil) - m.On("Read", mock.AnythingOfType("[]uint8")).Return(resp, nil) - m.On("Key").Return(tc.key) + if tc.expEncypted { + req, err = p.encrypt(req) + require.NoError(t, err) + resp, err = p.encrypt(resp) + require.NoError(t, err) + } + + mc.On("Write", mock.AnythingOfType("[]uint8")).Return(len(req), nil) + mc.On("Read", mock.AnythingOfType("[]uint8")).Return(resp, nil) - p := newQueryer(tc.version)(m) i, err := p.Query() require.NoError(t, err) require.IsType(t, &Info{}, i) require.Equal(t, &tc.expected, i) - m.AssertExpectations(t) + mc.AssertExpectations(t) }) } } + +func TestEncryptAndDecrypt(t *testing.T) { + mc := &clienttest.MockClient{} + mc.On("Key").Return("Z2ZkZ3Nnbmpza2U0cnRyZQ==") + p := queryer{ + c: mc, + } + + text := `Line 1: Some test text to be encrypted and decrypted +Line 2: Some test text to be encrypted and decrypted +Line 3: Some test text to be encrypted and decrypted +Line 4: Some test text to be encrypted and decrypted +Line 5: Some test text to be encrypted and decrypted +Line 6: Some test text to be encrypted and decrypted +Line 7: Some test text to be encrypted and decrypted +Line 8: Some test text to be encrypted and decrypted` + + encoded, err := p.encrypt([]byte(text)) + require.NoError(t, err) + + decoded, err := p.decrypt(encoded) + require.NoError(t, err) + require.Equal(t, text, string(decoded)) +} diff --git a/lib/svrquery/protocol/titanfall/testdata/request-v8 b/lib/svrquery/protocol/titanfall/testdata/request-v8 new file mode 100644 index 0000000..4c925c0 Binary files /dev/null and b/lib/svrquery/protocol/titanfall/testdata/request-v8 differ diff --git a/lib/svrquery/protocol/titanfall/testdata/response-v8 b/lib/svrquery/protocol/titanfall/testdata/response-v8 new file mode 100644 index 0000000..27830e2 Binary files /dev/null and b/lib/svrquery/protocol/titanfall/testdata/response-v8 differ diff --git a/lib/svrquery/protocol/titanfall/types.go b/lib/svrquery/protocol/titanfall/types.go index c86543c..c3eeab1 100644 --- a/lib/svrquery/protocol/titanfall/types.go +++ b/lib/svrquery/protocol/titanfall/types.go @@ -38,6 +38,11 @@ func (i Info) MaxClients() int64 { return int64(i.BasicInfo.MaxClients) } +// Map implements protocol.Mapper. +func (i Info) Map() string { + return i.BasicInfo.Map +} + // Header represents the header of a query response. type Header struct { Prefix int32 diff --git a/lib/svrsample/query.go b/lib/svrsample/query.go index 8cb019e..8fde442 100644 --- a/lib/svrsample/query.go +++ b/lib/svrsample/query.go @@ -10,8 +10,8 @@ import ( ) var ( - // ErrProtoNotFound returned when a protocol is not found - ErrProtoNotFound = errors.New("protocol not found") + // ErrProtoNotSupported returned when a protocol is not supported + ErrProtoNotSupported = errors.New("protocol not supported") ) // GetResponder gets the appropriate responder for the protocol provided @@ -20,5 +20,5 @@ func GetResponder(proto string, state common.QueryState) (common.QueryResponder, case "sqp": return sqp.NewQueryResponder(state) } - return nil, fmt.Errorf("%w: %s", ErrProtoNotFound, proto) + return nil, fmt.Errorf("%w: %s", ErrProtoNotSupported, proto) }