From 2a64bc18a122bae2dd0173d73a1a923ec62cc74e Mon Sep 17 00:00:00 2001 From: GUO YANKE Date: Tue, 23 Jan 2024 11:29:57 +0800 Subject: [PATCH] feat: add model.Server --- model/all.go | 1 + model/dao/gen.go | 44 +++--- model/dao/keys.gen.go | 100 +++++++++++- model/dao/servers.gen.go | 331 +++++++++++++++++++++++++++++++++++++++ model/dao/users.gen.go | 100 +++++++++++- model/key.go | 7 +- model/server.go | 10 ++ model/user.go | 2 + ssh.go | 43 ++++- 9 files changed, 608 insertions(+), 30 deletions(-) create mode 100644 model/dao/servers.gen.go create mode 100644 model/server.go diff --git a/model/all.go b/model/all.go index 4f9958f..a34ba03 100644 --- a/model/all.go +++ b/model/all.go @@ -3,4 +3,5 @@ package model var All = []any{ User{}, Key{}, + Server{}, } diff --git a/model/dao/gen.go b/model/dao/gen.go index f23a83f..e24a19c 100644 --- a/model/dao/gen.go +++ b/model/dao/gen.go @@ -16,39 +16,44 @@ import ( ) var ( - Q = new(Query) - Key *key - User *user + Q = new(Query) + Key *key + Server *server + User *user ) func SetDefault(db *gorm.DB, opts ...gen.DOOption) { *Q = *Use(db, opts...) Key = &Q.Key + Server = &Q.Server User = &Q.User } func Use(db *gorm.DB, opts ...gen.DOOption) *Query { return &Query{ - db: db, - Key: newKey(db, opts...), - User: newUser(db, opts...), + db: db, + Key: newKey(db, opts...), + Server: newServer(db, opts...), + User: newUser(db, opts...), } } type Query struct { db *gorm.DB - Key key - User user + Key key + Server server + User user } func (q *Query) Available() bool { return q.db != nil } func (q *Query) clone(db *gorm.DB) *Query { return &Query{ - db: db, - Key: q.Key.clone(db), - User: q.User.clone(db), + db: db, + Key: q.Key.clone(db), + Server: q.Server.clone(db), + User: q.User.clone(db), } } @@ -62,21 +67,24 @@ func (q *Query) WriteDB() *Query { func (q *Query) ReplaceDB(db *gorm.DB) *Query { return &Query{ - db: db, - Key: q.Key.replaceDB(db), - User: q.User.replaceDB(db), + db: db, + Key: q.Key.replaceDB(db), + Server: q.Server.replaceDB(db), + User: q.User.replaceDB(db), } } type queryCtx struct { - Key *keyDo - User *userDo + Key *keyDo + Server *serverDo + User *userDo } func (q *Query) WithContext(ctx context.Context) *queryCtx { return &queryCtx{ - Key: q.Key.WithContext(ctx), - User: q.User.WithContext(ctx), + Key: q.Key.WithContext(ctx), + Server: q.Server.WithContext(ctx), + User: q.User.WithContext(ctx), } } diff --git a/model/dao/keys.gen.go b/model/dao/keys.gen.go index a50c7b7..0059d28 100644 --- a/model/dao/keys.gen.go +++ b/model/dao/keys.gen.go @@ -31,6 +31,24 @@ func newKey(db *gorm.DB, opts ...gen.DOOption) key { _key.DisplayName = field.NewString(tableName, "display_name") _key.UserID = field.NewString(tableName, "user_id") _key.CreatedAt = field.NewTime(tableName, "created_at") + _key.User = keyBelongsToUser{ + db: db.Session(&gorm.Session{}), + + RelationField: field.NewRelation("User", "model.User"), + Keys: struct { + field.RelationField + User struct { + field.RelationField + } + }{ + RelationField: field.NewRelation("User.Keys", "model.Key"), + User: struct { + field.RelationField + }{ + RelationField: field.NewRelation("User.Keys.User", "model.User"), + }, + }, + } _key.fillFieldMap() @@ -45,6 +63,7 @@ type key struct { DisplayName field.String UserID field.String CreatedAt field.Time + User keyBelongsToUser fieldMap map[string]field.Expr } @@ -81,11 +100,12 @@ func (k *key) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (k *key) fillFieldMap() { - k.fieldMap = make(map[string]field.Expr, 4) + k.fieldMap = make(map[string]field.Expr, 5) k.fieldMap["id"] = k.ID k.fieldMap["display_name"] = k.DisplayName k.fieldMap["user_id"] = k.UserID k.fieldMap["created_at"] = k.CreatedAt + } func (k key) clone(db *gorm.DB) key { @@ -98,6 +118,84 @@ func (k key) replaceDB(db *gorm.DB) key { return k } +type keyBelongsToUser struct { + db *gorm.DB + + field.RelationField + + Keys struct { + field.RelationField + User struct { + field.RelationField + } + } +} + +func (a keyBelongsToUser) Where(conds ...field.Expr) *keyBelongsToUser { + if len(conds) == 0 { + return &a + } + + exprs := make([]clause.Expression, 0, len(conds)) + for _, cond := range conds { + exprs = append(exprs, cond.BeCond().(clause.Expression)) + } + a.db = a.db.Clauses(clause.Where{Exprs: exprs}) + return &a +} + +func (a keyBelongsToUser) WithContext(ctx context.Context) *keyBelongsToUser { + a.db = a.db.WithContext(ctx) + return &a +} + +func (a keyBelongsToUser) Session(session *gorm.Session) *keyBelongsToUser { + a.db = a.db.Session(session) + return &a +} + +func (a keyBelongsToUser) Model(m *model.Key) *keyBelongsToUserTx { + return &keyBelongsToUserTx{a.db.Model(m).Association(a.Name())} +} + +type keyBelongsToUserTx struct{ tx *gorm.Association } + +func (a keyBelongsToUserTx) Find() (result *model.User, err error) { + return result, a.tx.Find(&result) +} + +func (a keyBelongsToUserTx) Append(values ...*model.User) (err error) { + targetValues := make([]interface{}, len(values)) + for i, v := range values { + targetValues[i] = v + } + return a.tx.Append(targetValues...) +} + +func (a keyBelongsToUserTx) Replace(values ...*model.User) (err error) { + targetValues := make([]interface{}, len(values)) + for i, v := range values { + targetValues[i] = v + } + return a.tx.Replace(targetValues...) +} + +func (a keyBelongsToUserTx) Delete(values ...*model.User) (err error) { + targetValues := make([]interface{}, len(values)) + for i, v := range values { + targetValues[i] = v + } + return a.tx.Delete(targetValues...) +} + +func (a keyBelongsToUserTx) Clear() error { + return a.tx.Clear() +} + +func (a keyBelongsToUserTx) Count() int64 { + return a.tx.Count() +} + type keyDo struct{ gen.DO } func (k keyDo) Debug() *keyDo { diff --git a/model/dao/servers.gen.go b/model/dao/servers.gen.go new file mode 100644 index 0000000..13835bb --- /dev/null +++ b/model/dao/servers.gen.go @@ -0,0 +1,331 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package dao + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gorm.io/gen" + "gorm.io/gen/field" + + "gorm.io/plugin/dbresolver" + + "github.com/yankeguo/bunker/model" +) + +func newServer(db *gorm.DB, opts ...gen.DOOption) server { + _server := server{} + + _server.serverDo.UseDB(db, opts...) + _server.serverDo.UseModel(&model.Server{}) + + tableName := _server.serverDo.TableName() + _server.ALL = field.NewAsterisk(tableName) + _server.ID = field.NewString(tableName, "id") + _server.Address = field.NewString(tableName, "address") + _server.User = field.NewString(tableName, "user") + _server.CreatedAt = field.NewTime(tableName, "created_at") + + _server.fillFieldMap() + + return _server +} + +type server struct { + serverDo + + ALL field.Asterisk + ID field.String + Address field.String + User field.String + CreatedAt field.Time + + fieldMap map[string]field.Expr +} + +func (s server) Table(newTableName string) *server { + s.serverDo.UseTable(newTableName) + return s.updateTableName(newTableName) +} + +func (s server) As(alias string) *server { + s.serverDo.DO = *(s.serverDo.As(alias).(*gen.DO)) + return s.updateTableName(alias) +} + +func (s *server) updateTableName(table string) *server { + s.ALL = field.NewAsterisk(table) + s.ID = field.NewString(table, "id") + s.Address = field.NewString(table, "address") + s.User = field.NewString(table, "user") + s.CreatedAt = field.NewTime(table, "created_at") + + s.fillFieldMap() + + return s +} + +func (s *server) GetFieldByName(fieldName string) (field.OrderExpr, bool) { + _f, ok := s.fieldMap[fieldName] + if !ok || _f == nil { + return nil, false + } + _oe, ok := _f.(field.OrderExpr) + return _oe, ok +} + +func (s *server) fillFieldMap() { + s.fieldMap = make(map[string]field.Expr, 4) + s.fieldMap["id"] = s.ID + s.fieldMap["address"] = s.Address + s.fieldMap["user"] = s.User + s.fieldMap["created_at"] = s.CreatedAt +} + +func (s server) clone(db *gorm.DB) server { + s.serverDo.ReplaceConnPool(db.Statement.ConnPool) + return s +} + +func (s server) replaceDB(db *gorm.DB) server { + s.serverDo.ReplaceDB(db) + return s +} + +type serverDo struct{ gen.DO } + +func (s serverDo) Debug() *serverDo { + return s.withDO(s.DO.Debug()) +} + +func (s serverDo) WithContext(ctx context.Context) *serverDo { + return s.withDO(s.DO.WithContext(ctx)) +} + +func (s serverDo) ReadDB() *serverDo { + return s.Clauses(dbresolver.Read) +} + +func (s serverDo) WriteDB() *serverDo { + return s.Clauses(dbresolver.Write) +} + +func (s serverDo) Session(config *gorm.Session) *serverDo { + return s.withDO(s.DO.Session(config)) +} + +func (s serverDo) Clauses(conds ...clause.Expression) *serverDo { + return s.withDO(s.DO.Clauses(conds...)) +} + +func (s serverDo) Returning(value interface{}, columns ...string) *serverDo { + return s.withDO(s.DO.Returning(value, columns...)) +} + +func (s serverDo) Not(conds ...gen.Condition) *serverDo { + return s.withDO(s.DO.Not(conds...)) +} + +func (s serverDo) Or(conds ...gen.Condition) *serverDo { + return s.withDO(s.DO.Or(conds...)) +} + +func (s serverDo) Select(conds ...field.Expr) *serverDo { + return s.withDO(s.DO.Select(conds...)) +} + +func (s serverDo) Where(conds ...gen.Condition) *serverDo { + return s.withDO(s.DO.Where(conds...)) +} + +func (s serverDo) Order(conds ...field.Expr) *serverDo { + return s.withDO(s.DO.Order(conds...)) +} + +func (s serverDo) Distinct(cols ...field.Expr) *serverDo { + return s.withDO(s.DO.Distinct(cols...)) +} + +func (s serverDo) Omit(cols ...field.Expr) *serverDo { + return s.withDO(s.DO.Omit(cols...)) +} + +func (s serverDo) Join(table schema.Tabler, on ...field.Expr) *serverDo { + return s.withDO(s.DO.Join(table, on...)) +} + +func (s serverDo) LeftJoin(table schema.Tabler, on ...field.Expr) *serverDo { + return s.withDO(s.DO.LeftJoin(table, on...)) +} + +func (s serverDo) RightJoin(table schema.Tabler, on ...field.Expr) *serverDo { + return s.withDO(s.DO.RightJoin(table, on...)) +} + +func (s serverDo) Group(cols ...field.Expr) *serverDo { + return s.withDO(s.DO.Group(cols...)) +} + +func (s serverDo) Having(conds ...gen.Condition) *serverDo { + return s.withDO(s.DO.Having(conds...)) +} + +func (s serverDo) Limit(limit int) *serverDo { + return s.withDO(s.DO.Limit(limit)) +} + +func (s serverDo) Offset(offset int) *serverDo { + return s.withDO(s.DO.Offset(offset)) +} + +func (s serverDo) Scopes(funcs ...func(gen.Dao) gen.Dao) *serverDo { + return s.withDO(s.DO.Scopes(funcs...)) +} + +func (s serverDo) Unscoped() *serverDo { + return s.withDO(s.DO.Unscoped()) +} + +func (s serverDo) Create(values ...*model.Server) error { + if len(values) == 0 { + return nil + } + return s.DO.Create(values) +} + +func (s serverDo) CreateInBatches(values []*model.Server, batchSize int) error { + return s.DO.CreateInBatches(values, batchSize) +} + +// Save : !!! underlying implementation is different with GORM +// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values) +func (s serverDo) Save(values ...*model.Server) error { + if len(values) == 0 { + return nil + } + return s.DO.Save(values) +} + +func (s serverDo) First() (*model.Server, error) { + if result, err := s.DO.First(); err != nil { + return nil, err + } else { + return result.(*model.Server), nil + } +} + +func (s serverDo) Take() (*model.Server, error) { + if result, err := s.DO.Take(); err != nil { + return nil, err + } else { + return result.(*model.Server), nil + } +} + +func (s serverDo) Last() (*model.Server, error) { + if result, err := s.DO.Last(); err != nil { + return nil, err + } else { + return result.(*model.Server), nil + } +} + +func (s serverDo) Find() ([]*model.Server, error) { + result, err := s.DO.Find() + return result.([]*model.Server), err +} + +func (s serverDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Server, err error) { + buf := make([]*model.Server, 0, batchSize) + err = s.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error { + defer func() { results = append(results, buf...) }() + return fc(tx, batch) + }) + return results, err +} + +func (s serverDo) FindInBatches(result *[]*model.Server, batchSize int, fc func(tx gen.Dao, batch int) error) error { + return s.DO.FindInBatches(result, batchSize, fc) +} + +func (s serverDo) Attrs(attrs ...field.AssignExpr) *serverDo { + return s.withDO(s.DO.Attrs(attrs...)) +} + +func (s serverDo) Assign(attrs ...field.AssignExpr) *serverDo { + return s.withDO(s.DO.Assign(attrs...)) +} + +func (s serverDo) Joins(fields ...field.RelationField) *serverDo { + for _, _f := range fields { + s = *s.withDO(s.DO.Joins(_f)) + } + return &s +} + +func (s serverDo) Preload(fields ...field.RelationField) *serverDo { + for _, _f := range fields { + s = *s.withDO(s.DO.Preload(_f)) + } + return &s +} + +func (s serverDo) FirstOrInit() (*model.Server, error) { + if result, err := s.DO.FirstOrInit(); err != nil { + return nil, err + } else { + return result.(*model.Server), nil + } +} + +func (s serverDo) FirstOrCreate() (*model.Server, error) { + if result, err := s.DO.FirstOrCreate(); err != nil { + return nil, err + } else { + return result.(*model.Server), nil + } +} + +func (s serverDo) FindByPage(offset int, limit int) (result []*model.Server, count int64, err error) { + result, err = s.Offset(offset).Limit(limit).Find() + if err != nil { + return + } + + if size := len(result); 0 < limit && 0 < size && size < limit { + count = int64(size + offset) + return + } + + count, err = s.Offset(-1).Limit(-1).Count() + return +} + +func (s serverDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) { + count, err = s.Count() + if err != nil { + return + } + + err = s.Offset(offset).Limit(limit).Scan(result) + return +} + +func (s serverDo) Scan(result interface{}) (err error) { + return s.DO.Scan(result) +} + +func (s serverDo) Delete(models ...*model.Server) (result gen.ResultInfo, err error) { + return s.DO.Delete(models) +} + +func (s *serverDo) withDO(do gen.Dao) *serverDo { + s.DO = *do.(*gen.DO) + return s +} diff --git a/model/dao/users.gen.go b/model/dao/users.gen.go index 164b85b..629822f 100644 --- a/model/dao/users.gen.go +++ b/model/dao/users.gen.go @@ -33,6 +33,24 @@ func newUser(db *gorm.DB, opts ...gen.DOOption) user { _user.VisitedAt = field.NewTime(tableName, "visited_at") _user.IsAdmin = field.NewBool(tableName, "is_admin") _user.IsBlocked = field.NewBool(tableName, "is_blocked") + _user.Keys = userHasManyKeys{ + db: db.Session(&gorm.Session{}), + + RelationField: field.NewRelation("Keys", "model.Key"), + User: struct { + field.RelationField + Keys struct { + field.RelationField + } + }{ + RelationField: field.NewRelation("Keys.User", "model.User"), + Keys: struct { + field.RelationField + }{ + RelationField: field.NewRelation("Keys.User.Keys", "model.Key"), + }, + }, + } _user.fillFieldMap() @@ -49,6 +67,7 @@ type user struct { VisitedAt field.Time IsAdmin field.Bool IsBlocked field.Bool + Keys userHasManyKeys fieldMap map[string]field.Expr } @@ -87,13 +106,14 @@ func (u *user) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (u *user) fillFieldMap() { - u.fieldMap = make(map[string]field.Expr, 6) + u.fieldMap = make(map[string]field.Expr, 7) u.fieldMap["id"] = u.ID u.fieldMap["password_digest"] = u.PasswordDigest u.fieldMap["created_at"] = u.CreatedAt u.fieldMap["visited_at"] = u.VisitedAt u.fieldMap["is_admin"] = u.IsAdmin u.fieldMap["is_blocked"] = u.IsBlocked + } func (u user) clone(db *gorm.DB) user { @@ -106,6 +126,84 @@ func (u user) replaceDB(db *gorm.DB) user { return u } +type userHasManyKeys struct { + db *gorm.DB + + field.RelationField + + User struct { + field.RelationField + Keys struct { + field.RelationField + } + } +} + +func (a userHasManyKeys) Where(conds ...field.Expr) *userHasManyKeys { + if len(conds) == 0 { + return &a + } + + exprs := make([]clause.Expression, 0, len(conds)) + for _, cond := range conds { + exprs = append(exprs, cond.BeCond().(clause.Expression)) + } + a.db = a.db.Clauses(clause.Where{Exprs: exprs}) + return &a +} + +func (a userHasManyKeys) WithContext(ctx context.Context) *userHasManyKeys { + a.db = a.db.WithContext(ctx) + return &a +} + +func (a userHasManyKeys) Session(session *gorm.Session) *userHasManyKeys { + a.db = a.db.Session(session) + return &a +} + +func (a userHasManyKeys) Model(m *model.User) *userHasManyKeysTx { + return &userHasManyKeysTx{a.db.Model(m).Association(a.Name())} +} + +type userHasManyKeysTx struct{ tx *gorm.Association } + +func (a userHasManyKeysTx) Find() (result []*model.Key, err error) { + return result, a.tx.Find(&result) +} + +func (a userHasManyKeysTx) Append(values ...*model.Key) (err error) { + targetValues := make([]interface{}, len(values)) + for i, v := range values { + targetValues[i] = v + } + return a.tx.Append(targetValues...) +} + +func (a userHasManyKeysTx) Replace(values ...*model.Key) (err error) { + targetValues := make([]interface{}, len(values)) + for i, v := range values { + targetValues[i] = v + } + return a.tx.Replace(targetValues...) +} + +func (a userHasManyKeysTx) Delete(values ...*model.Key) (err error) { + targetValues := make([]interface{}, len(values)) + for i, v := range values { + targetValues[i] = v + } + return a.tx.Delete(targetValues...) +} + +func (a userHasManyKeysTx) Clear() error { + return a.tx.Clear() +} + +func (a userHasManyKeysTx) Count() int64 { + return a.tx.Count() +} + type userDo struct{ gen.DO } func (u userDo) Debug() *userDo { diff --git a/model/key.go b/model/key.go index 15c9e00..65f0822 100644 --- a/model/key.go +++ b/model/key.go @@ -3,8 +3,9 @@ package model import "time" type Key struct { - ID string `gorm:"column:id;primarykey" json:"id"` - DisplayName string `gorm:"column:display_name" json:"display_name"` - UserID string `gorm:"column:user_id;index" json:"user_id"` + ID string `gorm:"column:id;primarykey" json:"id"` + DisplayName string `gorm:"column:display_name" json:"display_name"` + UserID string `gorm:"column:user_id;index" json:"user_id"` + User User CreatedAt time.Time `gorm:"column:created_at;index" json:"created_at"` } diff --git a/model/server.go b/model/server.go new file mode 100644 index 0000000..0464a90 --- /dev/null +++ b/model/server.go @@ -0,0 +1,10 @@ +package model + +import "time" + +type Server struct { + // ID will be user@server_name + ID string `gorm:"column:id;primaryKey" json:"id" ` + Address string `gorm:"column:address" json:"address"` + CreatedAt time.Time `gorm:"column:created_at;index" json:"created_at"` +} diff --git a/model/user.go b/model/user.go index d84ac95..7529463 100644 --- a/model/user.go +++ b/model/user.go @@ -17,6 +17,8 @@ type User struct { VisitedAt time.Time `gorm:"column:visited_at;not null;index" json:"visited_at"` IsAdmin bool `gorm:"column:is_admin;not null;default:0;index" json:"is_admin"` IsBlocked bool `gorm:"column:is_blocked;not null;default:0;index" json:"is_blocked"` + + Keys []Key } // SetPassword update password for user diff --git a/ssh.go b/ssh.go index f8c0b8b..4945f7a 100644 --- a/ssh.go +++ b/ssh.go @@ -17,7 +17,12 @@ import ( ) const ( - sshExtKeyUserID = "bunker.user_id" + sshExtKeyError = "bunker.error" + sshExtKeyUserID = "bunker.user_id" + sshExtKeyServerID = "bunker.server_id" + sshExtKeyServerUser = "bunker.server_user" + sshExtKeyServerName = "bunker.server_name" + sshExtKeyServerAddress = "bunker.server_address" ) type SSHServerParams struct { @@ -56,27 +61,51 @@ func (s *SSHServer) createServerConfig() *ssh.ServerConfig { PublicKeyCallback: func(conn ssh.ConnMetadata, _key ssh.PublicKey) (perm *ssh.Permissions, err error) { db := dao.Use(s.Database) + // find user key and user var key *model.Key if key, err = db.Key.Where(dao.Key.ID.Eq( strings.ToLower(ssh.FingerprintSHA256(_key)), - )).First(); err != nil { + )).Preload(dao.Key.User).First(); err != nil { return nil, err } - var user *model.User - if user, err = db.User.Where(dao.User.ID.Eq(key.UserID)).First(); err != nil { - return nil, err + if key.User.ID == "" { + err = errors.New("key is not associated with any user") + return + } + + // find server + splits := strings.Split(conn.User(), "@") + if len(splits) != 2 { + err = errors.New("invalid user format, should be user@server") + return } + var ( + serverUser = splits[0] + serverName = splits[1] + ) + + var server *model.Server + if server, err = db.Server.Where(dao.Server.ID.Eq(serverUser + "@" + serverName)).First(); err != nil { + return + } + + //TODO: VALIDATE GRANT + perm = &ssh.Permissions{ Extensions: map[string]string{ - sshExtKeyUserID: user.ID, + sshExtKeyUserID: key.User.ID, + sshExtKeyServerID: server.ID, + sshExtKeyServerAddress: server.Address, + sshExtKeyServerUser: serverUser, + sshExtKeyServerName: serverName, }, } return }, BannerCallback: func(conn ssh.ConnMetadata) string { - return "bunker from github.com/yankeguo/bunker" + return "[bunker] " }, }