Skip to content

Commit

Permalink
fix(qsign): nil pointer dereference
Browse files Browse the repository at this point in the history
修复启动时 `getAvaliableSignServer()` 内因指针为空导致 panic 的问题
  • Loading branch information
1umine committed Aug 29, 2023
1 parent fd6ef4a commit 71c08f5
Showing 1 changed file with 66 additions and 39 deletions.
105 changes: 66 additions & 39 deletions cmd/gocq/qsign.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,55 @@ import (
"github.com/Mrs4s/go-cqhttp/modules/config"
)

type currentSignServer atomic.Pointer[config.SignServer]
type currentSignServer struct {
server *config.SignServer
ok bool
muRW sync.RWMutex
}

func (c *currentSignServer) get() *config.SignServer {
return (*atomic.Pointer[config.SignServer])(c).Load()
func (c *currentSignServer) get() (server *config.SignServer, status bool) {
c.muRW.RLock()
defer c.muRW.RUnlock()
if c.server == nil {
c.server = &config.SignServer{}
}
return c.server, c.ok
}

func (c *currentSignServer) set(server *config.SignServer) {
(*atomic.Pointer[config.SignServer])(c).Store(server)
func (c *currentSignServer) set(server *config.SignServer, ok bool) {
c.muRW.Lock()
defer c.muRW.Unlock()
if server != nil {
c.server = server // 传入 nil 时保持原来的值
}
c.ok = ok // 设置 server 状态
}

// 当前签名服务器
var ss currentSignServer
// 当前正在使用的签名服务器
var usingServer currentSignServer

// 失败计数
type errconut atomic.Uintptr
// 连续失败计数
type errcount atomic.Uintptr

func (ec *errconut) hasOver(count uintptr) bool {
func (ec *errcount) hasOver(count uintptr) bool {
return (*atomic.Uintptr)(ec).Load() > count
}

func (ec *errconut) inc() {
func (ec *errcount) increment() {
(*atomic.Uintptr)(ec).Add(1)
}

var errn errconut
func (ec *errcount) toZero() {
(*atomic.Uintptr)(ec).Store(0)
}

var errn errcount // 连续找不到可用签名服务计数
var checkMutex sync.Mutex

// getAvaliableSignServer 获取可用的签名服务器,没有则返回空和相应错误
// getAvaliableSignServer 获取可用的签名服务器,没有则返回空结构体指针(防止panic)和相应错误
func getAvaliableSignServer() (*config.SignServer, error) {
cs := ss.get()
if cs != nil {
cs, ok := usingServer.get()
if ok {
return cs, nil
}
if len(base.SignServers) == 0 {
Expand All @@ -63,20 +82,20 @@ func getAvaliableSignServer() (*config.SignServer, error) {
if maxCount == 0 {
if errn.hasOver(3) {
log.Warn("已连续 3 次获取不到可用签名服务器,将固定使用主签名服务器")
ss.set(&base.SignServers[0])
return ss.get(), nil
usingServer.set(&base.SignServers[0], true)
return &base.SignServers[0], nil
}
} else if errn.hasOver(uintptr(maxCount)) {
log.Fatalf("获取可用签名服务器失败次数超过 %v 次, 正在离线", maxCount)
}
if len(cs.URL) > 0 {
log.Warnf("当前签名服务器 %v 不可用,正在查找可用服务器", cs.URL)
}
cs = asyncCheckServer(base.SignServers)
if cs == nil {
return nil, errors.New("no usable sign server")
if checkMutex.TryLock() { // 保证同时只执行一个检查,不确定不加锁会不会有别的什么问题,还是加一个(
defer checkMutex.Unlock()
return asyncCheckServer(base.SignServers)
}
return cs, nil
return &config.SignServer{}, errors.New("it is checking sign servers")
}

func isServerAvaliable(signServer string) bool {
Expand All @@ -91,31 +110,37 @@ func isServerAvaliable(signServer string) bool {
return false
}

// asyncCheckServer 按同步顺序检查所有签名服务器直到找到可用的
func asyncCheckServer(servers []config.SignServer) *config.SignServer {
doRegister := sync.Once{}
// 检查所有签名服务器直到找到可用的
func asyncCheckServer(servers []config.SignServer) (*config.SignServer, error) {
setServer := sync.Once{}
wg := sync.WaitGroup{}
wg.Add(len(servers))
for i, s := range servers {
go func(i int, server config.SignServer) {
log.Infof("检查签名服务器:%v (%v/%v)", s.URL, i+1, len(servers))
go func(server config.SignServer) {
defer wg.Done()
log.Infof("检查签名服务器:%v (%v/%v)", server.URL, i+1, len(servers))
if len(server.URL) < 4 {
return
}
if isServerAvaliable(server.URL) {
doRegister.Do(func() {
ss.set(&server)
setServer.Do(func() {
errn.toZero() // 计数归零
usingServer.set(&server, true)
log.Infof("使用签名服务器 url=%v, key=%v, auth=%v", server.URL, server.Key, server.Authorization)
if base.Account.AutoRegister {
// 若配置了自动注册实例则在切换后注册实例,否则不需要注册,签名时由qsign自动注册
signRegister(base.Account.Uin, device.AndroidId, device.Guid, device.QImei36, server.Key)
}
})
}
}(i, s)
}(s)
}
wg.Wait()
s, ok := usingServer.get()
if ok {
return s, nil
}
return ss.get()
return &config.SignServer{}, errors.New("no avaliable sign server")
}

/*
Expand All @@ -126,10 +151,10 @@ func asyncCheckServer(servers []config.SignServer) *config.SignServer {
*/
func requestSignServer(method string, url string, headers map[string]string, body io.Reader) (string, []byte, error) {
signServer, e := getAvaliableSignServer()
if e != nil && len(signServer.URL) == 0 { // 没有可用的
if e != nil {
log.Warnf("获取可用签名服务器出错:%v, 将使用主签名服务器进行签名", e)
errn.inc()
signServer = &base.SignServers[0] // 没有获取到时使用第一个
errn.increment() // 连续错误计数 +1
signServer = &base.SignServers[0] // 没有获取到时使用主签名服务器
}
if !strings.HasPrefix(url, signServer.URL) {
url = strings.TrimSuffix(signServer.URL, "/") + "/" + strings.TrimPrefix(url, "/")
Expand All @@ -149,7 +174,7 @@ func requestSignServer(method string, url string, headers map[string]string, bod
}.WithTimeout(time.Duration(base.SignServerTimeout) * time.Second)
resp, err := req.Bytes()
if err != nil {
ss.set(nil) // 标记为不可用
usingServer.set(nil, false) // 标记当前签名服务为不可用
}
return signServer.URL, resp, err
}
Expand Down Expand Up @@ -287,7 +312,7 @@ var lastToken = ""
func sign(seq uint64, uin string, cmd string, qua string, buff []byte) (sign []byte, extra []byte, token []byte, err error) {
i := 0
for {
cs := ss.get()
cs, _ := usingServer.get()
sign, extra, token, err = signRequset(seq, uin, cmd, qua, buff)
if err != nil {
log.Warnf("获取sso sign时出现错误: %v. server: %v", err, cs.URL)
Expand Down Expand Up @@ -332,7 +357,7 @@ func sign(seq uint64, uin string, cmd string, qua string, buff []byte) (sign []b
}
rule := base.Account.RuleChangeSignServer
if (len(sign) == 0 && rule >= 1) || (len(token) == 0 && rule >= 2) {
ss.set(nil)
usingServer.set(nil, false)
}
return sign, extra, token, err
}
Expand All @@ -345,9 +370,10 @@ func signServerDestroy(uin string) error {
if global.VersionNameCompare("v"+signVersion, "v1.1.6") {
return errors.Errorf("当前签名服务器版本 %v 低于 1.1.6,无法使用 destroy 接口", signVersion)
}
cs, _ := usingServer.get()
signServer, resp, err := requestSignServer(
http.MethodGet,
"destroy"+fmt.Sprintf("?uin=%v&key=%v", uin, ss.get().Key),
"destroy"+fmt.Sprintf("?uin=%v&key=%v", uin, cs.Key),
nil, nil,
)
if err != nil || gjson.GetBytes(resp, "code").Int() != 0 {
Expand Down Expand Up @@ -385,9 +411,10 @@ func signStartRefreshToken(interval int64) {
qqstr := strconv.FormatInt(base.Account.Uin, 10)
defer t.Stop()
for range t.C {
cs, master := ss.get(), base.SignServers[0]
cs, _ := usingServer.get()
master := &base.SignServers[0]
if cs.URL != master.URL && isServerAvaliable(master.URL) {
ss.set(&master)
usingServer.set(master, true)
log.Infof("主签名服务器可用,已切换至主签名服务器 %v", cs.URL)
}
err := signRefreshToken(qqstr)
Expand Down

0 comments on commit 71c08f5

Please sign in to comment.