Skip to content

Commit

Permalink
all: add some safety for client being nil
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Nov 14, 2024
1 parent 4dcaa22 commit 3c0f25d
Show file tree
Hide file tree
Showing 13 changed files with 64 additions and 6 deletions.
6 changes: 6 additions & 0 deletions appstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import (
// FetchAppState fetches updates to the given type of app state. If fullSync is true, the current
// cached state will be removed and all app state patches will be re-fetched from the server.
func (cli *Client) FetchAppState(name appstate.WAPatchName, fullSync, onlyIfNotSynced bool) error {
if cli == nil {
return ErrClientIsNil
}
cli.appStateSyncLock.Lock()
defer cli.appStateSyncLock.Unlock()
if fullSync {
Expand Down Expand Up @@ -347,6 +350,9 @@ func (cli *Client) requestAppStateKeys(ctx context.Context, rawKeyIDs [][]byte)
//
// cli.SendAppState(appstate.BuildMute(targetJID, true, 24 * time.Hour))
func (cli *Client) SendAppState(patch appstate.PatchInfo) error {
if cli == nil {
return ErrClientIsNil
}
version, hash, err := cli.Store.AppState.GetAppStateVersion(string(patch.Type))
if err != nil {
return err
Expand Down
23 changes: 20 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ func (cli *Client) closeSocketWaitChan() {
}

func (cli *Client) getOwnID() types.JID {
if cli == nil {
return types.EmptyJID
}
id := cli.Store.ID
if id == nil {
return types.EmptyJID
Expand All @@ -379,6 +382,9 @@ func (cli *Client) getOwnID() types.JID {
}

func (cli *Client) WaitForConnection(timeout time.Duration) bool {
if cli == nil {
return false
}
timeoutChan := time.After(timeout)
cli.socketLock.RLock()
for cli.socket == nil || !cli.socket.IsConnected() || !cli.IsLoggedIn() {
Expand All @@ -398,6 +404,9 @@ func (cli *Client) WaitForConnection(timeout time.Duration) bool {
// Connect connects the client to the WhatsApp web websocket. After connection, it will either
// authenticate if there's data in the device store, or emit a QREvent to set up a new link.
func (cli *Client) Connect() error {
if cli == nil {
return ErrClientIsNil
}
cli.socketLock.Lock()
defer cli.socketLock.Unlock()
if cli.socket != nil {
Expand Down Expand Up @@ -444,7 +453,7 @@ func (cli *Client) Connect() error {

// IsLoggedIn returns true after the client is successfully connected and authenticated on WhatsApp.
func (cli *Client) IsLoggedIn() bool {
return cli.isLoggedIn.Load()
return cli != nil && cli.isLoggedIn.Load()
}

func (cli *Client) onDisconnect(ns *socket.NoiseSocket, remote bool) {
Expand Down Expand Up @@ -508,6 +517,9 @@ func (cli *Client) autoReconnect() {
// IsConnected checks if the client is connected to the WhatsApp web websocket.
// Note that this doesn't check if the client is authenticated. See the IsLoggedIn field for that.
func (cli *Client) IsConnected() bool {
if cli == nil {
return false
}
cli.socketLock.RLock()
connected := cli.socket != nil && cli.socket.IsConnected()
cli.socketLock.RUnlock()
Expand All @@ -519,7 +531,7 @@ func (cli *Client) IsConnected() bool {
// This will not emit any events, the Disconnected event is only used when the
// connection is closed by the server or a network error.
func (cli *Client) Disconnect() {
if cli.socket == nil {
if cli == nil || cli.socket == nil {
return
}
cli.socketLock.Lock()
Expand All @@ -544,7 +556,9 @@ func (cli *Client) unlockedDisconnect() {
// Note that this will not emit any events. The LoggedOut event is only used for external logouts
// (triggered by the user from the main device or by WhatsApp servers).
func (cli *Client) Logout() error {
if cli.MessengerConfig != nil {
if cli == nil {
return ErrClientIsNil
} else if cli.MessengerConfig != nil {
return errors.New("can't logout with Messenger credentials")
}
ownID := cli.getOwnID()
Expand Down Expand Up @@ -728,6 +742,9 @@ func (cli *Client) handlerQueueLoop(ctx context.Context) {
}

func (cli *Client) sendNodeAndGetData(node waBinary.Node) ([]byte, error) {
if cli == nil {
return nil, ErrClientIsNil
}
cli.socketLock.RLock()
sock := cli.socket
cli.socketLock.RUnlock()
Expand Down
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

// Miscellaneous errors
var (
ErrClientIsNil = errors.New("client is nil")
ErrNoSession = errors.New("can't encrypt message for device: no signal session established")
ErrIQTimedOut = errors.New("info query timed out")
ErrNotConnected = errors.New("websocket not connected")
Expand Down
3 changes: 3 additions & 0 deletions mediaconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ func (mc *MediaConn) Expiry() time.Time {
}

func (cli *Client) refreshMediaConn(force bool) (*MediaConn, error) {
if cli == nil {
return nil, ErrClientIsNil
}
cli.mediaConnLock.Lock()
defer cli.mediaConnLock.Unlock()
if cli.mediaConnCache == nil || force || time.Now().After(cli.mediaConnCache.Expiry()) {
Expand Down
6 changes: 6 additions & 0 deletions msgsecret.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ type messageEncryptedSecret interface {
}

func (cli *Client) decryptMsgSecret(msg *events.Message, useCase MsgSecretType, encrypted messageEncryptedSecret, origMsgKey *waCommon.MessageKey) ([]byte, error) {
if cli == nil {
return nil, ErrClientIsNil
}
pollSender, err := getOrigSenderFromKey(msg, origMsgKey)
if err != nil {
return nil, err
Expand All @@ -102,6 +105,9 @@ func (cli *Client) decryptMsgSecret(msg *events.Message, useCase MsgSecretType,
}

func (cli *Client) encryptMsgSecret(chat, origSender types.JID, origMsgID types.MessageID, useCase MsgSecretType, plaintext []byte) (ciphertext, iv []byte, err error) {
if cli == nil {
return nil, nil, ErrClientIsNil
}
ownID := cli.getOwnID()
if ownID.IsEmpty() {
return nil, nil, ErrNotLoggedIn
Expand Down
3 changes: 3 additions & 0 deletions newsletter.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func (cli *Client) NewsletterSubscribeLiveUpdates(ctx context.Context, jid types
//
// This is not the same as marking the channel as read on your other devices, use the usual MarkRead function for that.
func (cli *Client) NewsletterMarkViewed(jid types.JID, serverIDs []types.MessageServerID) error {
if cli == nil {
return ErrClientIsNil
}
items := make([]waBinary.Node, len(serverIDs))
for i, id := range serverIDs {
items[i] = waBinary.Node{
Expand Down
3 changes: 3 additions & 0 deletions pair-code.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ func generateCompanionEphemeralKey() (ephemeralKeyPair *keys.KeyPair, ephemeralK
//
// See https://faq.whatsapp.com/1324084875126592 for more info
func (cli *Client) PairPhone(phone string, showPushNotification bool, clientType PairClientType, clientDisplayName string) (string, error) {
if cli == nil {
return "", ErrClientIsNil
}
ephemeralKeyPair, ephemeralKey, encodedLinkingCode := generateCompanionEphemeralKey()
phone = notNumbers.ReplaceAllString(phone, "")
if len(phone) <= 6 {
Expand Down
2 changes: 1 addition & 1 deletion privacysettings.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (cli *Client) TryFetchPrivacySettings(ignoreCache bool) (*types.PrivacySett
// GetPrivacySettings will get the user's privacy settings. If an error occurs while fetching them, the error will be
// logged, but the method will just return an empty struct.
func (cli *Client) GetPrivacySettings() (settings types.PrivacySettings) {
if cli.MessengerConfig != nil {
if cli == nil || cli.MessengerConfig != nil {
return
}
settingsPtr, err := cli.TryFetchPrivacySettings(false)
Expand Down
4 changes: 3 additions & 1 deletion qrchan.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ func (qrc *qrChannel) handleEvent(rawEvt interface{}) {
// The last value to be emitted will be a special event like "success", "timeout" or another error code
// depending on the result of the pairing. The channel will be closed immediately after one of those.
func (cli *Client) GetQRChannel(ctx context.Context) (<-chan QRChannelItem, error) {
if cli.IsConnected() {
if cli == nil {
return nil, ErrClientIsNil
} else if cli.IsConnected() {
return nil, ErrQRAlreadyConnected
} else if cli.Store.ID != nil {
return nil, ErrQRStoreContainsID
Expand Down
3 changes: 3 additions & 0 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ type infoQuery struct {
}

func (cli *Client) sendIQAsyncAndGetData(query *infoQuery) (<-chan *waBinary.Node, []byte, error) {
if cli == nil {
return nil, nil, ErrClientIsNil
}
if len(query.ID) == 0 {
query.ID = cli.generateRequestID()
}
Expand Down
6 changes: 5 additions & 1 deletion send.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
// msgID := cli.GenerateMessageID()
// cli.SendMessage(context.Background(), targetJID, &waProto.Message{...}, whatsmeow.SendRequestExtra{ID: msgID})
func (cli *Client) GenerateMessageID() types.MessageID {
if cli.MessengerConfig != nil {
if cli != nil && cli.MessengerConfig != nil {
return types.MessageID(strconv.FormatInt(GenerateFacebookMessageID(), 10))
}
data := make([]byte, 8, 8+20+16)
Expand Down Expand Up @@ -167,6 +167,10 @@ type SendRequestExtra struct {
// field in incoming message events to figure out what it contains is also a good way to learn how to
// send the same kind of message.
func (cli *Client) SendMessage(ctx context.Context, to types.JID, message *waE2E.Message, extra ...SendRequestExtra) (resp SendResponse, err error) {
if cli == nil {
err = ErrClientIsNil
return
}
var req SendRequestExtra
if len(extra) > 1 {
err = errors.New("only one extra parameter may be provided to SendMessage")
Expand Down
4 changes: 4 additions & 0 deletions sendfb.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func (cli *Client) SendFBMessage(
metadata *waMsgApplication.MessageApplication_Metadata,
extra ...SendRequestExtra,
) (resp SendResponse, err error) {
if cli == nil {
err = ErrClientIsNil
return
}
var req SendRequestExtra
if len(extra) > 1 {
err = errors.New("only one extra parameter may be provided to SendMessage")
Expand Down
6 changes: 6 additions & 0 deletions user.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,9 @@ type UsyncQueryExtras struct {
}

func (cli *Client) usync(ctx context.Context, jids []types.JID, mode, context string, query []waBinary.Node, extra ...UsyncQueryExtras) (*waBinary.Node, error) {
if cli == nil {
return nil, ErrClientIsNil
}
var extras UsyncQueryExtras
if len(extra) > 1 {
return nil, errors.New("only one extra parameter may be provided to usync()")
Expand Down Expand Up @@ -844,6 +847,9 @@ func (cli *Client) UpdateBlocklist(jid types.JID, action events.BlocklistChangeA
},
}},
})
if err != nil {
return nil, err
}
list, ok := resp.GetOptionalChildByTag("list")
if !ok {
return nil, &ElementMissingError{Tag: "list", In: "response to blocklist update"}
Expand Down

0 comments on commit 3c0f25d

Please sign in to comment.