diff --git a/core/mysql.go b/core/mysql.go index c66e2628..97c6cd4b 100644 --- a/core/mysql.go +++ b/core/mysql.go @@ -136,6 +136,29 @@ func (mysql *Mysql) CleanData(id uint) error { return nil } +// GetUserByName 通过用户名来获取用户 +func (mysql *Mysql) GetUserByName(name string) *User { + db := mysql.GetDB() + if db == nil { + return nil + } + defer db.Close() + var ( + username string + originPass string + download uint64 + upload uint64 + quota int64 + id uint + ) + row := db.QueryRow(fmt.Sprintf("SELECT * FROM users WHERE username='%s'", name)) + if err := row.Scan(&id, &username, &originPass, "a, &download, &upload); err != nil { + fmt.Println(err) + return nil + } + return &User{ID: id, Username: username, Password: originPass, Download: download, Upload: upload, Quota: quota} +} + // GetData 获取用户记录 func (mysql *Mysql) GetData(ids ...string) []*User { var dataList []*User diff --git a/trojan/user.go b/trojan/user.go index 7cb916cb..fbf58b12 100644 --- a/trojan/user.go +++ b/trojan/user.go @@ -32,6 +32,10 @@ func AddUser() { fmt.Println(util.Yellow("不能新建用户名为'admin'的用户!")) return } + if _, err := core.GetValue(inputUser + "_pass"); err == nil { + fmt.Println(util.Yellow("已存在用户名为: " + inputUser + " 的用户!")) + return + } inputPass := util.Input(fmt.Sprintf("生成随机密码: %s, 使用直接回车, 否则输入自定义密码: ", randomPass), randomPass) mysql := core.GetMysql() if mysql.CreateUser(inputUser, inputPass) == nil { diff --git a/web/auth.go b/web/auth.go index 1a0173d1..78cd1c8f 100644 --- a/web/auth.go +++ b/web/auth.go @@ -44,7 +44,10 @@ func init() { } }, Authenticator: func(c *gin.Context) (interface{}, error) { - var loginVals Login + var ( + password string + loginVals Login + ) if err := c.ShouldBind(&loginVals); err != nil { return "", jwt.ErrMissingLoginValues } @@ -53,15 +56,25 @@ func init() { if err != nil { return nil, err } - if value, err := core.GetValue(userID + "_pass"); err != nil { - return nil, err - } else if value == pass { + if userID != "admin" { + mysql := core.GetMysql() + user := mysql.GetUserByName(userID) + if user == nil { + return nil, jwt.ErrFailedAuthentication + } + password = user.Password + } else { + if password, err = core.GetValue(userID + "_pass"); err != nil { + return nil, err + } + } + if password == pass { return &loginVals, nil } return nil, jwt.ErrFailedAuthentication }, Authorizator: func(data interface{}, c *gin.Context) bool { - if v, ok := data.(*Login); ok && v.Username == "admin" { + if _, ok := data.(*Login); ok { return true } return false @@ -85,7 +98,7 @@ func init() { func updateUser(c *gin.Context) { responseBody := controller.ResponseBody{Msg: "success"} defer controller.TimeCost(time.Now(), &responseBody) - username := c.DefaultPostForm("username", "admin") + username := c.PostForm("username") pass := c.PostForm("password") err := core.SetValue(fmt.Sprintf("%s_pass", username), pass) if err != nil { @@ -94,6 +107,12 @@ func updateUser(c *gin.Context) { c.JSON(200, responseBody) } +// RequestUsername 获取请求接口的用户名 +func RequestUsername(c *gin.Context) string { + claims := jwt.ExtractClaims(c) + return claims[identityKey].(string) +} + // Auth 权限router func Auth(r *gin.Engine) *jwt.GinJWTMiddleware { r.NoRoute(authMiddleware.MiddlewareFunc(), func(c *gin.Context) { diff --git a/web/controller/user.go b/web/controller/user.go index 56eeb18a..a76cf8ed 100644 --- a/web/controller/user.go +++ b/web/controller/user.go @@ -7,11 +7,19 @@ import ( ) // UserList 获取用户列表 -func UserList() *ResponseBody { +func UserList(findUser string) *ResponseBody { responseBody := ResponseBody{Msg: "success"} defer TimeCost(time.Now(), &responseBody) mysql := core.GetMysql() userList := mysql.GetData() + if findUser != "" { + for _, user := range userList { + if user.Username == findUser { + userList = []*core.User{user} + break + } + } + } if userList == nil { responseBody.Msg = "连接mysql失败!" return &responseBody @@ -35,6 +43,10 @@ func CreateUser(username string, password string) *ResponseBody { responseBody.Msg = "不能创建用户名为admin的用户!" return &responseBody } + if _, err := core.GetValue(username + "_pass"); err == nil { + responseBody.Msg = "已存在用户名为: " + username + " 的用户!" + return &responseBody + } mysql := core.GetMysql() pass, err := base64.StdEncoding.DecodeString(password) if err != nil { diff --git a/web/web.go b/web/web.go index bad69466..bf6648f4 100644 --- a/web/web.go +++ b/web/web.go @@ -16,7 +16,12 @@ func userRouter(router *gin.Engine) { user := router.Group("/trojan/user") { user.GET("", func(c *gin.Context) { - c.JSON(200, controller.UserList()) + requestUser := RequestUsername(c) + if requestUser == "admin" { + c.JSON(200, controller.UserList("")) + } else { + c.JSON(200, controller.UserList(requestUser)) + } }) user.POST("", func(c *gin.Context) { username := c.PostForm("username")