diff --git a/cmd/gocq/qsign.go b/cmd/gocq/qsign.go index 1bf019196..f02ea8b42 100644 --- a/cmd/gocq/qsign.go +++ b/cmd/gocq/qsign.go @@ -24,36 +24,55 @@ import ( "github.com/Mrs4s/go-cqhttp/modules/config" ) -type currentSignServer atomic.Pointer[config.SignServer] +type currentSignServer struct { + server *config.SignServer + ok bool + muRW sync.RWMutex +} -func (c *currentSignServer) get() *config.SignServer { - return (*atomic.Pointer[config.SignServer])(c).Load() +func (c *currentSignServer) get() (server *config.SignServer, status bool) { + c.muRW.RLock() + defer c.muRW.RUnlock() + if c.server == nil { + c.server = &config.SignServer{} + } + return c.server, c.ok } -func (c *currentSignServer) set(server *config.SignServer) { - (*atomic.Pointer[config.SignServer])(c).Store(server) +func (c *currentSignServer) set(server *config.SignServer, ok bool) { + c.muRW.Lock() + defer c.muRW.Unlock() + if server != nil { + c.server = server // 传入 nil 时保持原来的值 + } + c.ok = ok // 设置 server 状态 } -// 当前签名服务器 -var ss currentSignServer +// 当前正在使用的签名服务器 +var usingServer currentSignServer -// 失败计数 -type errconut atomic.Uintptr +// 连续失败计数 +type errcount atomic.Uintptr -func (ec *errconut) hasOver(count uintptr) bool { +func (ec *errcount) hasOver(count uintptr) bool { return (*atomic.Uintptr)(ec).Load() > count } -func (ec *errconut) inc() { +func (ec *errcount) increment() { (*atomic.Uintptr)(ec).Add(1) } -var errn errconut +func (ec *errcount) toZero() { + (*atomic.Uintptr)(ec).Store(0) +} + +var errn errcount // 连续找不到可用签名服务计数 +var checkMutex sync.Mutex -// getAvaliableSignServer 获取可用的签名服务器,没有则返回空和相应错误 +// getAvaliableSignServer 获取可用的签名服务器,没有则返回空结构体指针(防止panic)和相应错误 func getAvaliableSignServer() (*config.SignServer, error) { - cs := ss.get() - if cs != nil { + cs, ok := usingServer.get() + if ok { return cs, nil } if len(base.SignServers) == 0 { @@ -63,8 +82,8 @@ func getAvaliableSignServer() (*config.SignServer, error) { if maxCount == 0 { if errn.hasOver(3) { log.Warn("已连续 3 次获取不到可用签名服务器,将固定使用主签名服务器") - ss.set(&base.SignServers[0]) - return ss.get(), nil + usingServer.set(&base.SignServers[0], true) + return &base.SignServers[0], nil } } else if errn.hasOver(uintptr(maxCount)) { log.Fatalf("获取可用签名服务器失败次数超过 %v 次, 正在离线", maxCount) @@ -72,11 +91,11 @@ func getAvaliableSignServer() (*config.SignServer, error) { if len(cs.URL) > 0 { log.Warnf("当前签名服务器 %v 不可用,正在查找可用服务器", cs.URL) } - cs = asyncCheckServer(base.SignServers) - if cs == nil { - return nil, errors.New("no usable sign server") + if checkMutex.TryLock() { // 保证同时只执行一个检查,不确定不加锁会不会有别的什么问题,还是加一个( + defer checkMutex.Unlock() + return asyncCheckServer(base.SignServers) } - return cs, nil + return &config.SignServer{}, errors.New("it is checking sign servers") } func isServerAvaliable(signServer string) bool { @@ -91,21 +110,22 @@ func isServerAvaliable(signServer string) bool { return false } -// asyncCheckServer 按同步顺序检查所有签名服务器直到找到可用的 -func asyncCheckServer(servers []config.SignServer) *config.SignServer { - doRegister := sync.Once{} +// 检查所有签名服务器直到找到可用的 +func asyncCheckServer(servers []config.SignServer) (*config.SignServer, error) { + setServer := sync.Once{} wg := sync.WaitGroup{} wg.Add(len(servers)) for i, s := range servers { - go func(i int, server config.SignServer) { + log.Infof("检查签名服务器:%v (%v/%v)", s.URL, i+1, len(servers)) + go func(server config.SignServer) { defer wg.Done() - log.Infof("检查签名服务器:%v (%v/%v)", server.URL, i+1, len(servers)) if len(server.URL) < 4 { return } if isServerAvaliable(server.URL) { - doRegister.Do(func() { - ss.set(&server) + setServer.Do(func() { + errn.toZero() // 计数归零 + usingServer.set(&server, true) log.Infof("使用签名服务器 url=%v, key=%v, auth=%v", server.URL, server.Key, server.Authorization) if base.Account.AutoRegister { // 若配置了自动注册实例则在切换后注册实例,否则不需要注册,签名时由qsign自动注册 @@ -113,9 +133,14 @@ func asyncCheckServer(servers []config.SignServer) *config.SignServer { } }) } - }(i, s) + }(s) + } + wg.Wait() + s, ok := usingServer.get() + if ok { + return s, nil } - return ss.get() + return &config.SignServer{}, errors.New("no avaliable sign server") } /* @@ -126,10 +151,10 @@ func asyncCheckServer(servers []config.SignServer) *config.SignServer { */ func requestSignServer(method string, url string, headers map[string]string, body io.Reader) (string, []byte, error) { signServer, e := getAvaliableSignServer() - if e != nil && len(signServer.URL) == 0 { // 没有可用的 + if e != nil { log.Warnf("获取可用签名服务器出错:%v, 将使用主签名服务器进行签名", e) - errn.inc() - signServer = &base.SignServers[0] // 没有获取到时使用第一个 + errn.increment() // 连续错误计数 +1 + signServer = &base.SignServers[0] // 没有获取到时使用主签名服务器 } if !strings.HasPrefix(url, signServer.URL) { url = strings.TrimSuffix(signServer.URL, "/") + "/" + strings.TrimPrefix(url, "/") @@ -149,7 +174,7 @@ func requestSignServer(method string, url string, headers map[string]string, bod }.WithTimeout(time.Duration(base.SignServerTimeout) * time.Second) resp, err := req.Bytes() if err != nil { - ss.set(nil) // 标记为不可用 + usingServer.set(nil, false) // 标记当前签名服务为不可用 } return signServer.URL, resp, err } @@ -287,7 +312,7 @@ var lastToken = "" func sign(seq uint64, uin string, cmd string, qua string, buff []byte) (sign []byte, extra []byte, token []byte, err error) { i := 0 for { - cs := ss.get() + cs, _ := usingServer.get() sign, extra, token, err = signRequset(seq, uin, cmd, qua, buff) if err != nil { log.Warnf("获取sso sign时出现错误: %v. server: %v", err, cs.URL) @@ -332,7 +357,7 @@ func sign(seq uint64, uin string, cmd string, qua string, buff []byte) (sign []b } rule := base.Account.RuleChangeSignServer if (len(sign) == 0 && rule >= 1) || (len(token) == 0 && rule >= 2) { - ss.set(nil) + usingServer.set(nil, false) } return sign, extra, token, err } @@ -345,9 +370,10 @@ func signServerDestroy(uin string) error { if global.VersionNameCompare("v"+signVersion, "v1.1.6") { return errors.Errorf("当前签名服务器版本 %v 低于 1.1.6,无法使用 destroy 接口", signVersion) } + cs, _ := usingServer.get() signServer, resp, err := requestSignServer( http.MethodGet, - "destroy"+fmt.Sprintf("?uin=%v&key=%v", uin, ss.get().Key), + "destroy"+fmt.Sprintf("?uin=%v&key=%v", uin, cs.Key), nil, nil, ) if err != nil || gjson.GetBytes(resp, "code").Int() != 0 { @@ -385,9 +411,10 @@ func signStartRefreshToken(interval int64) { qqstr := strconv.FormatInt(base.Account.Uin, 10) defer t.Stop() for range t.C { - cs, master := ss.get(), base.SignServers[0] + cs, _ := usingServer.get() + master := &base.SignServers[0] if cs.URL != master.URL && isServerAvaliable(master.URL) { - ss.set(&master) + usingServer.set(master, true) log.Infof("主签名服务器可用,已切换至主签名服务器 %v", cs.URL) } err := signRefreshToken(qqstr)