Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(qsign): nil pointer dereference #2410

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 66 additions & 39 deletions cmd/gocq/qsign.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,55 @@ import (
"github.com/Mrs4s/go-cqhttp/modules/config"
)

type currentSignServer atomic.Pointer[config.SignServer]
type currentSignServer struct {
server *config.SignServer
ok bool
muRW sync.RWMutex
}

func (c *currentSignServer) get() *config.SignServer {
return (*atomic.Pointer[config.SignServer])(c).Load()
func (c *currentSignServer) get() (server *config.SignServer, status bool) {
c.muRW.RLock()
defer c.muRW.RUnlock()
if c.server == nil {
c.server = &config.SignServer{}
}
return c.server, c.ok
}

func (c *currentSignServer) set(server *config.SignServer) {
(*atomic.Pointer[config.SignServer])(c).Store(server)
func (c *currentSignServer) set(server *config.SignServer, ok bool) {
c.muRW.Lock()
defer c.muRW.Unlock()
if server != nil {
c.server = server // 传入 nil 时保持原来的值
}
c.ok = ok // 设置 server 状态
}

// 当前签名服务器
var ss currentSignServer
// 当前正在使用的签名服务器
var usingServer currentSignServer

// 失败计数
type errconut atomic.Uintptr
// 连续失败计数
type errcount atomic.Uintptr

func (ec *errconut) hasOver(count uintptr) bool {
func (ec *errcount) hasOver(count uintptr) bool {
return (*atomic.Uintptr)(ec).Load() > count
}

func (ec *errconut) inc() {
func (ec *errcount) increment() {
(*atomic.Uintptr)(ec).Add(1)
}

var errn errconut
func (ec *errcount) toZero() {
(*atomic.Uintptr)(ec).Store(0)
}

var errn errcount // 连续找不到可用签名服务计数
var checkMutex sync.Mutex

// getAvaliableSignServer 获取可用的签名服务器,没有则返回空和相应错误
// getAvaliableSignServer 获取可用的签名服务器,没有则返回空结构体指针(防止panic)和相应错误
func getAvaliableSignServer() (*config.SignServer, error) {
cs := ss.get()
if cs != nil {
cs, ok := usingServer.get()
if ok {
return cs, nil
}
if len(base.SignServers) == 0 {
Expand All @@ -63,20 +82,20 @@ func getAvaliableSignServer() (*config.SignServer, error) {
if maxCount == 0 {
if errn.hasOver(3) {
log.Warn("已连续 3 次获取不到可用签名服务器,将固定使用主签名服务器")
ss.set(&base.SignServers[0])
return ss.get(), nil
usingServer.set(&base.SignServers[0], true)
return &base.SignServers[0], nil
}
} else if errn.hasOver(uintptr(maxCount)) {
log.Fatalf("获取可用签名服务器失败次数超过 %v 次, 正在离线", maxCount)
}
if len(cs.URL) > 0 {
log.Warnf("当前签名服务器 %v 不可用,正在查找可用服务器", cs.URL)
}
cs = asyncCheckServer(base.SignServers)
if cs == nil {
return nil, errors.New("no usable sign server")
if checkMutex.TryLock() { // 保证同时只执行一个检查,不确定不加锁会不会有别的什么问题,还是加一个(
defer checkMutex.Unlock()
return asyncCheckServer(base.SignServers)
}
return cs, nil
return &config.SignServer{}, errors.New("it is checking sign servers")
}

func isServerAvaliable(signServer string) bool {
Expand All @@ -91,31 +110,37 @@ func isServerAvaliable(signServer string) bool {
return false
}

// asyncCheckServer 按同步顺序检查所有签名服务器直到找到可用的
func asyncCheckServer(servers []config.SignServer) *config.SignServer {
doRegister := sync.Once{}
// 检查所有签名服务器直到找到可用的
func asyncCheckServer(servers []config.SignServer) (*config.SignServer, error) {
setServer := sync.Once{}
wg := sync.WaitGroup{}
wg.Add(len(servers))
for i, s := range servers {
go func(i int, server config.SignServer) {
log.Infof("检查签名服务器:%v (%v/%v)", s.URL, i+1, len(servers))
go func(server config.SignServer) {
defer wg.Done()
log.Infof("检查签名服务器:%v (%v/%v)", server.URL, i+1, len(servers))
if len(server.URL) < 4 {
return
}
if isServerAvaliable(server.URL) {
doRegister.Do(func() {
ss.set(&server)
setServer.Do(func() {
errn.toZero() // 计数归零
usingServer.set(&server, true)
log.Infof("使用签名服务器 url=%v, key=%v, auth=%v", server.URL, server.Key, server.Authorization)
if base.Account.AutoRegister {
// 若配置了自动注册实例则在切换后注册实例,否则不需要注册,签名时由qsign自动注册
signRegister(base.Account.Uin, device.AndroidId, device.Guid, device.QImei36, server.Key)
}
})
}
}(i, s)
}(s)
}
wg.Wait()
s, ok := usingServer.get()
if ok {
return s, nil
}
return ss.get()
return &config.SignServer{}, errors.New("no avaliable sign server")
}

/*
Expand All @@ -126,10 +151,10 @@ func asyncCheckServer(servers []config.SignServer) *config.SignServer {
*/
func requestSignServer(method string, url string, headers map[string]string, body io.Reader) (string, []byte, error) {
signServer, e := getAvaliableSignServer()
if e != nil && len(signServer.URL) == 0 { // 没有可用的
if e != nil {
log.Warnf("获取可用签名服务器出错:%v, 将使用主签名服务器进行签名", e)
errn.inc()
signServer = &base.SignServers[0] // 没有获取到时使用第一个
errn.increment() // 连续错误计数 +1
signServer = &base.SignServers[0] // 没有获取到时使用主签名服务器
}
if !strings.HasPrefix(url, signServer.URL) {
url = strings.TrimSuffix(signServer.URL, "/") + "/" + strings.TrimPrefix(url, "/")
Expand All @@ -149,7 +174,7 @@ func requestSignServer(method string, url string, headers map[string]string, bod
}.WithTimeout(time.Duration(base.SignServerTimeout) * time.Second)
resp, err := req.Bytes()
if err != nil {
ss.set(nil) // 标记为不可用
usingServer.set(nil, false) // 标记当前签名服务为不可用
}
return signServer.URL, resp, err
}
Expand Down Expand Up @@ -287,7 +312,7 @@ var lastToken = ""
func sign(seq uint64, uin string, cmd string, qua string, buff []byte) (sign []byte, extra []byte, token []byte, err error) {
i := 0
for {
cs := ss.get()
cs, _ := usingServer.get()
sign, extra, token, err = signRequset(seq, uin, cmd, qua, buff)
if err != nil {
log.Warnf("获取sso sign时出现错误: %v. server: %v", err, cs.URL)
Expand Down Expand Up @@ -332,7 +357,7 @@ func sign(seq uint64, uin string, cmd string, qua string, buff []byte) (sign []b
}
rule := base.Account.RuleChangeSignServer
if (len(sign) == 0 && rule >= 1) || (len(token) == 0 && rule >= 2) {
ss.set(nil)
usingServer.set(nil, false)
}
return sign, extra, token, err
}
Expand All @@ -345,9 +370,10 @@ func signServerDestroy(uin string) error {
if global.VersionNameCompare("v"+signVersion, "v1.1.6") {
return errors.Errorf("当前签名服务器版本 %v 低于 1.1.6,无法使用 destroy 接口", signVersion)
}
cs, _ := usingServer.get()
signServer, resp, err := requestSignServer(
http.MethodGet,
"destroy"+fmt.Sprintf("?uin=%v&key=%v", uin, ss.get().Key),
"destroy"+fmt.Sprintf("?uin=%v&key=%v", uin, cs.Key),
nil, nil,
)
if err != nil || gjson.GetBytes(resp, "code").Int() != 0 {
Expand Down Expand Up @@ -385,9 +411,10 @@ func signStartRefreshToken(interval int64) {
qqstr := strconv.FormatInt(base.Account.Uin, 10)
defer t.Stop()
for range t.C {
cs, master := ss.get(), base.SignServers[0]
cs, _ := usingServer.get()
master := &base.SignServers[0]
if cs.URL != master.URL && isServerAvaliable(master.URL) {
ss.set(&master)
usingServer.set(master, true)
log.Infof("主签名服务器可用,已切换至主签名服务器 %v", cs.URL)
}
err := signRefreshToken(qqstr)
Expand Down
Loading