Skip to content
This repository has been archived by the owner on Jul 22, 2024. It is now read-only.

Fleshed out the client protocol unit tests. #16

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
152 changes: 104 additions & 48 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Package vnc implements a VNC client.
//
// References:
// [PROTOCOL]: http://tools.ietf.org/html/rfc6143
/*
Package vnc implements a VNC client.

References:
[PROTOCOL]: http://tools.ietf.org/html/rfc6143
*/
package vnc

import (
Expand All @@ -13,6 +15,7 @@ import (
"unicode"
)

// The ClientConn type holds client connection information.
type ClientConn struct {
c net.Conn
config *ClientConfig
Expand Down Expand Up @@ -72,7 +75,23 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) {
config: cfg,
}

if err := conn.handshake(); err != nil {
if err := conn.protocolVersionHandshake(); err != nil {
conn.Close()
return nil, err
}
if err := conn.securityHandshake(); err != nil {
conn.Close()
return nil, err
}
if err := conn.securityResultHandshake(); err != nil {
conn.Close()
return nil, err
}
if err := conn.clientInit(); err != nil {
conn.Close()
return nil, err
}
if err := conn.serverInit(); err != nil {
conn.Close()
return nil, err
}
Expand Down Expand Up @@ -299,42 +318,57 @@ func parseProtocolVersion(pv []byte) (uint, uint, error) {
return major, minor, nil
}

func (c *ClientConn) handshake() error {
const (
// Client ProtocolVersions.
PROTO_VERS_UNSUP = "UNSUPPORTED"
PROTO_VERS_3_8 = "RFB 003.008\n"
)

// protocolVersionHandshake implements §7.1.1 ProtocolVersion Handshake.
func (c *ClientConn) protocolVersionHandshake() error {
var protocolVersion [pvLen]byte

// 7.1.1, read the ProtocolVersion message sent by the server.
// Read the ProtocolVersion message sent by the server.
if _, err := io.ReadFull(c.c, protocolVersion[:]); err != nil {
return err
}

maxMajor, maxMinor, err := parseProtocolVersion(protocolVersion[:])
major, minor, err := parseProtocolVersion(protocolVersion[:])
if err != nil {
return err
}
if maxMajor < 3 {
return fmt.Errorf("unsupported major version, less than 3: %d", maxMajor)
pv := PROTO_VERS_UNSUP
if major == 3 && minor >= 8 {
pv = PROTO_VERS_3_8
}
if maxMinor < 8 {
return fmt.Errorf("unsupported minor version, less than 8: %d", maxMinor)
if pv == PROTO_VERS_UNSUP {
return NewVNCError(fmt.Sprintf("ProtocolVersion handshake failed; unsupported version '%v'", string(protocolVersion[:])))
}

// Respond with the version we will support
if _, err = c.c.Write([]byte("RFB 003.008\n")); err != nil {
if _, err = c.c.Write([]byte(pv)); err != nil {
return err
}

// 7.1.2 Security Handshake from server
return nil
}

// securityHandshake implements §7.1.2 Security Handshake.
func (c *ClientConn) securityHandshake() error {
var numSecurityTypes uint8
if err = binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil {
if err := binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil {
return err
}

if numSecurityTypes == 0 {
return fmt.Errorf("no security types: %s", c.readErrorReason())
reason, err := c.readErrorReason()
if err != nil {
return err
}
return NewVNCError(fmt.Sprintf("Security handshake failed; no security types: %v", reason))
}

securityTypes := make([]uint8, numSecurityTypes)
if err = binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil {
if err := binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil {
return err
}

Expand All @@ -354,64 +388,72 @@ FindAuth:
}
}
}

if auth == nil {
return fmt.Errorf("no suitable auth schemes found. server supported: %#v", securityTypes)
return NewVNCError(fmt.Sprintf("Security handshake failed; no suitable auth schemes found; server supports: %#v", securityTypes))
}

// Respond back with the security type we'll use
if err = binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil {
if err := binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil {
return err
}

if err = auth.Handshake(c.c); err != nil {
if err := auth.Handshake(c.c); err != nil {
return err
}
return nil
}

// 7.1.3 SecurityResult Handshake
// securityResultHandshake implements §7.1.3 SecurityResult Handshake.
func (c *ClientConn) securityResultHandshake() error {
var securityResult uint32
if err = binary.Read(c.c, binary.BigEndian, &securityResult); err != nil {

if err := binary.Read(c.c, binary.BigEndian, &securityResult); err != nil {
return err
}

if securityResult == 1 {
return fmt.Errorf("security handshake failed: %s", c.readErrorReason())
reason, err := c.readErrorReason()
if err != nil {
return err
}
return NewVNCError(fmt.Sprintf("SecurityResult handshake failed: %s", reason))
}

// 7.3.1 ClientInit
var sharedFlag uint8 = 1
if c.config.Exclusive {
sharedFlag = 0
}
return nil
}

if err = binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil {
// clientInit implements §7.3.1 ClientInit.
func (c *ClientConn) clientInit() error {
var sharedFlag uint8

if !c.config.Exclusive {
sharedFlag = 1
}
if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil {
return err
}

// 7.3.2 ServerInit
if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil {
return nil
}

// serverInit implements §7.3.2 ServerInit.
func (c *ClientConn) serverInit() error {
if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil {
return err
}

if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil {
if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil {
return err
}

// Read the pixel format
if err = readPixelFormat(c.c, &c.PixelFormat); err != nil {
if err := readPixelFormat(c.c, &c.PixelFormat); err != nil {
return err
}

var nameLength uint32
if err = binary.Read(c.c, binary.BigEndian, &nameLength); err != nil {
if err := binary.Read(c.c, binary.BigEndian, &nameLength); err != nil {
return err
}

nameBytes := make([]uint8, nameLength)
if err = binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil {
if err := binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil {
return err
}

c.DesktopName = string(nameBytes)

return nil
Expand Down Expand Up @@ -467,16 +509,30 @@ func (c *ClientConn) mainLoop() {
}
}

func (c *ClientConn) readErrorReason() string {
func (c *ClientConn) readErrorReason() (string, error) {
var reasonLen uint32
if err := binary.Read(c.c, binary.BigEndian, &reasonLen); err != nil {
return "<error>"
return "", err
}

reason := make([]uint8, reasonLen)
if err := binary.Read(c.c, binary.BigEndian, &reason); err != nil {
return "<error>"
return "", err
}

return string(reason)
return string(reason), nil
}

// VNCError implements error interface.
type VNCError struct {
s string
}

// NewVNCError returns a custom VNCError error.
func NewVNCError(s string) error {
return &VNCError{s}
}

func (e VNCError) Error() string {
return e.s
}
10 changes: 8 additions & 2 deletions client_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ import (
"net"
)

const (
secTypeInvalid = iota
secTypeNone
secTypeVNCAuth
)

// A ClientAuth implements a method of authenticating with a remote server.
type ClientAuth interface {
// SecurityType returns the byte identifier sent by the server to
Expand All @@ -16,10 +22,10 @@ type ClientAuth interface {
}

// ClientAuthNone is the "none" authentication. See 7.1.2
type ClientAuthNone byte
type ClientAuthNone struct{}

func (*ClientAuthNone) SecurityType() uint8 {
return 1
return secTypeNone
}

func (*ClientAuthNone) Handshake(net.Conn) error {
Expand Down
Loading