diff --git a/model/all.go b/model/all.go index 6678820..4f9958f 100644 --- a/model/all.go +++ b/model/all.go @@ -2,4 +2,5 @@ package model var All = []any{ User{}, + Key{}, } diff --git a/model/dao/gen.go b/model/dao/gen.go index 9c6fe57..f23a83f 100644 --- a/model/dao/gen.go +++ b/model/dao/gen.go @@ -17,17 +17,20 @@ import ( var ( Q = new(Query) + Key *key User *user ) func SetDefault(db *gorm.DB, opts ...gen.DOOption) { *Q = *Use(db, opts...) + Key = &Q.Key User = &Q.User } func Use(db *gorm.DB, opts ...gen.DOOption) *Query { return &Query{ db: db, + Key: newKey(db, opts...), User: newUser(db, opts...), } } @@ -35,6 +38,7 @@ func Use(db *gorm.DB, opts ...gen.DOOption) *Query { type Query struct { db *gorm.DB + Key key User user } @@ -43,6 +47,7 @@ func (q *Query) Available() bool { return q.db != nil } func (q *Query) clone(db *gorm.DB) *Query { return &Query{ db: db, + Key: q.Key.clone(db), User: q.User.clone(db), } } @@ -58,16 +63,19 @@ func (q *Query) WriteDB() *Query { func (q *Query) ReplaceDB(db *gorm.DB) *Query { return &Query{ db: db, + Key: q.Key.replaceDB(db), User: q.User.replaceDB(db), } } type queryCtx struct { + Key *keyDo User *userDo } func (q *Query) WithContext(ctx context.Context) *queryCtx { return &queryCtx{ + Key: q.Key.WithContext(ctx), User: q.User.WithContext(ctx), } } diff --git a/model/dao/keys.gen.go b/model/dao/keys.gen.go new file mode 100644 index 0000000..a50c7b7 --- /dev/null +++ b/model/dao/keys.gen.go @@ -0,0 +1,331 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package dao + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gorm.io/gen" + "gorm.io/gen/field" + + "gorm.io/plugin/dbresolver" + + "github.com/yankeguo/bunker/model" +) + +func newKey(db *gorm.DB, opts ...gen.DOOption) key { + _key := key{} + + _key.keyDo.UseDB(db, opts...) + _key.keyDo.UseModel(&model.Key{}) + + tableName := _key.keyDo.TableName() + _key.ALL = field.NewAsterisk(tableName) + _key.ID = field.NewString(tableName, "id") + _key.DisplayName = field.NewString(tableName, "display_name") + _key.UserID = field.NewString(tableName, "user_id") + _key.CreatedAt = field.NewTime(tableName, "created_at") + + _key.fillFieldMap() + + return _key +} + +type key struct { + keyDo + + ALL field.Asterisk + ID field.String + DisplayName field.String + UserID field.String + CreatedAt field.Time + + fieldMap map[string]field.Expr +} + +func (k key) Table(newTableName string) *key { + k.keyDo.UseTable(newTableName) + return k.updateTableName(newTableName) +} + +func (k key) As(alias string) *key { + k.keyDo.DO = *(k.keyDo.As(alias).(*gen.DO)) + return k.updateTableName(alias) +} + +func (k *key) updateTableName(table string) *key { + k.ALL = field.NewAsterisk(table) + k.ID = field.NewString(table, "id") + k.DisplayName = field.NewString(table, "display_name") + k.UserID = field.NewString(table, "user_id") + k.CreatedAt = field.NewTime(table, "created_at") + + k.fillFieldMap() + + return k +} + +func (k *key) GetFieldByName(fieldName string) (field.OrderExpr, bool) { + _f, ok := k.fieldMap[fieldName] + if !ok || _f == nil { + return nil, false + } + _oe, ok := _f.(field.OrderExpr) + return _oe, ok +} + +func (k *key) fillFieldMap() { + k.fieldMap = make(map[string]field.Expr, 4) + k.fieldMap["id"] = k.ID + k.fieldMap["display_name"] = k.DisplayName + k.fieldMap["user_id"] = k.UserID + k.fieldMap["created_at"] = k.CreatedAt +} + +func (k key) clone(db *gorm.DB) key { + k.keyDo.ReplaceConnPool(db.Statement.ConnPool) + return k +} + +func (k key) replaceDB(db *gorm.DB) key { + k.keyDo.ReplaceDB(db) + return k +} + +type keyDo struct{ gen.DO } + +func (k keyDo) Debug() *keyDo { + return k.withDO(k.DO.Debug()) +} + +func (k keyDo) WithContext(ctx context.Context) *keyDo { + return k.withDO(k.DO.WithContext(ctx)) +} + +func (k keyDo) ReadDB() *keyDo { + return k.Clauses(dbresolver.Read) +} + +func (k keyDo) WriteDB() *keyDo { + return k.Clauses(dbresolver.Write) +} + +func (k keyDo) Session(config *gorm.Session) *keyDo { + return k.withDO(k.DO.Session(config)) +} + +func (k keyDo) Clauses(conds ...clause.Expression) *keyDo { + return k.withDO(k.DO.Clauses(conds...)) +} + +func (k keyDo) Returning(value interface{}, columns ...string) *keyDo { + return k.withDO(k.DO.Returning(value, columns...)) +} + +func (k keyDo) Not(conds ...gen.Condition) *keyDo { + return k.withDO(k.DO.Not(conds...)) +} + +func (k keyDo) Or(conds ...gen.Condition) *keyDo { + return k.withDO(k.DO.Or(conds...)) +} + +func (k keyDo) Select(conds ...field.Expr) *keyDo { + return k.withDO(k.DO.Select(conds...)) +} + +func (k keyDo) Where(conds ...gen.Condition) *keyDo { + return k.withDO(k.DO.Where(conds...)) +} + +func (k keyDo) Order(conds ...field.Expr) *keyDo { + return k.withDO(k.DO.Order(conds...)) +} + +func (k keyDo) Distinct(cols ...field.Expr) *keyDo { + return k.withDO(k.DO.Distinct(cols...)) +} + +func (k keyDo) Omit(cols ...field.Expr) *keyDo { + return k.withDO(k.DO.Omit(cols...)) +} + +func (k keyDo) Join(table schema.Tabler, on ...field.Expr) *keyDo { + return k.withDO(k.DO.Join(table, on...)) +} + +func (k keyDo) LeftJoin(table schema.Tabler, on ...field.Expr) *keyDo { + return k.withDO(k.DO.LeftJoin(table, on...)) +} + +func (k keyDo) RightJoin(table schema.Tabler, on ...field.Expr) *keyDo { + return k.withDO(k.DO.RightJoin(table, on...)) +} + +func (k keyDo) Group(cols ...field.Expr) *keyDo { + return k.withDO(k.DO.Group(cols...)) +} + +func (k keyDo) Having(conds ...gen.Condition) *keyDo { + return k.withDO(k.DO.Having(conds...)) +} + +func (k keyDo) Limit(limit int) *keyDo { + return k.withDO(k.DO.Limit(limit)) +} + +func (k keyDo) Offset(offset int) *keyDo { + return k.withDO(k.DO.Offset(offset)) +} + +func (k keyDo) Scopes(funcs ...func(gen.Dao) gen.Dao) *keyDo { + return k.withDO(k.DO.Scopes(funcs...)) +} + +func (k keyDo) Unscoped() *keyDo { + return k.withDO(k.DO.Unscoped()) +} + +func (k keyDo) Create(values ...*model.Key) error { + if len(values) == 0 { + return nil + } + return k.DO.Create(values) +} + +func (k keyDo) CreateInBatches(values []*model.Key, batchSize int) error { + return k.DO.CreateInBatches(values, batchSize) +} + +// Save : !!! underlying implementation is different with GORM +// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values) +func (k keyDo) Save(values ...*model.Key) error { + if len(values) == 0 { + return nil + } + return k.DO.Save(values) +} + +func (k keyDo) First() (*model.Key, error) { + if result, err := k.DO.First(); err != nil { + return nil, err + } else { + return result.(*model.Key), nil + } +} + +func (k keyDo) Take() (*model.Key, error) { + if result, err := k.DO.Take(); err != nil { + return nil, err + } else { + return result.(*model.Key), nil + } +} + +func (k keyDo) Last() (*model.Key, error) { + if result, err := k.DO.Last(); err != nil { + return nil, err + } else { + return result.(*model.Key), nil + } +} + +func (k keyDo) Find() ([]*model.Key, error) { + result, err := k.DO.Find() + return result.([]*model.Key), err +} + +func (k keyDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Key, err error) { + buf := make([]*model.Key, 0, batchSize) + err = k.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error { + defer func() { results = append(results, buf...) }() + return fc(tx, batch) + }) + return results, err +} + +func (k keyDo) FindInBatches(result *[]*model.Key, batchSize int, fc func(tx gen.Dao, batch int) error) error { + return k.DO.FindInBatches(result, batchSize, fc) +} + +func (k keyDo) Attrs(attrs ...field.AssignExpr) *keyDo { + return k.withDO(k.DO.Attrs(attrs...)) +} + +func (k keyDo) Assign(attrs ...field.AssignExpr) *keyDo { + return k.withDO(k.DO.Assign(attrs...)) +} + +func (k keyDo) Joins(fields ...field.RelationField) *keyDo { + for _, _f := range fields { + k = *k.withDO(k.DO.Joins(_f)) + } + return &k +} + +func (k keyDo) Preload(fields ...field.RelationField) *keyDo { + for _, _f := range fields { + k = *k.withDO(k.DO.Preload(_f)) + } + return &k +} + +func (k keyDo) FirstOrInit() (*model.Key, error) { + if result, err := k.DO.FirstOrInit(); err != nil { + return nil, err + } else { + return result.(*model.Key), nil + } +} + +func (k keyDo) FirstOrCreate() (*model.Key, error) { + if result, err := k.DO.FirstOrCreate(); err != nil { + return nil, err + } else { + return result.(*model.Key), nil + } +} + +func (k keyDo) FindByPage(offset int, limit int) (result []*model.Key, count int64, err error) { + result, err = k.Offset(offset).Limit(limit).Find() + if err != nil { + return + } + + if size := len(result); 0 < limit && 0 < size && size < limit { + count = int64(size + offset) + return + } + + count, err = k.Offset(-1).Limit(-1).Count() + return +} + +func (k keyDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) { + count, err = k.Count() + if err != nil { + return + } + + err = k.Offset(offset).Limit(limit).Scan(result) + return +} + +func (k keyDo) Scan(result interface{}) (err error) { + return k.DO.Scan(result) +} + +func (k keyDo) Delete(models ...*model.Key) (result gen.ResultInfo, err error) { + return k.DO.Delete(models) +} + +func (k *keyDo) withDO(do gen.Dao) *keyDo { + k.DO = *do.(*gen.DO) + return k +} diff --git a/model/key.go b/model/key.go new file mode 100644 index 0000000..15c9e00 --- /dev/null +++ b/model/key.go @@ -0,0 +1,10 @@ +package model + +import "time" + +type Key struct { + ID string `gorm:"column:id;primarykey" json:"id"` + DisplayName string `gorm:"column:display_name" json:"display_name"` + UserID string `gorm:"column:user_id;index" json:"user_id"` + CreatedAt time.Time `gorm:"column:created_at;index" json:"created_at"` +} diff --git a/ssh.go b/ssh.go index b91b032..b255ec7 100644 --- a/ssh.go +++ b/ssh.go @@ -13,8 +13,11 @@ import ( "net" "os" "path/filepath" + "strings" "time" + "github.com/yankeguo/bunker/model" + "github.com/yankeguo/bunker/model/dao" "github.com/yankeguo/ufx" "go.uber.org/fx" "golang.org/x/crypto/ed25519" @@ -26,6 +29,8 @@ const ( sshHostKeyFileRSA = "ssh_host_rsa_key" sshHostKeyFileECDSA = "ssh_host_ecdsa_key" sshHostKeyFileEd25519 = "ssh_host_ed25519_key" + + sshExtKeyUserID = "bunker.user_id" ) type SSHServerParams struct { @@ -125,14 +130,32 @@ func (s *SSHServer) ensureHostSigners() (err error) { return } -func (s *SSHServer) HandleServerConn(conn net.Conn) { - defer conn.Close() - - var err error - +func (s *SSHServer) createServerConfig() *ssh.ServerConfig { cfg := &ssh.ServerConfig{ - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - return nil, nil + AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) { + log.Println("auth:", conn.RemoteAddr(), conn.User(), method, err) + }, + PublicKeyCallback: func(conn ssh.ConnMetadata, _key ssh.PublicKey) (perm *ssh.Permissions, err error) { + db := dao.Use(s.Database) + + var key *model.Key + if key, err = db.Key.Where(dao.Key.ID.Eq( + strings.ToLower(ssh.FingerprintSHA256(_key)), + )).First(); err != nil { + return nil, err + } + + var user *model.User + if user, err = db.User.Where(dao.User.ID.Eq(key.UserID)).First(); err != nil { + return nil, err + } + + perm = &ssh.Permissions{ + Extensions: map[string]string{ + sshExtKeyUserID: user.ID, + }, + } + return }, BannerCallback: func(conn ssh.ConnMetadata) string { return "bunker from github.com/yankeguo/bunker" @@ -143,13 +166,21 @@ func (s *SSHServer) HandleServerConn(conn net.Conn) { cfg.AddHostKey(sgn) } + return cfg +} + +func (s *SSHServer) HandleServerConn(conn net.Conn) { + defer conn.Close() + + var err error + var ( sc *ssh.ServerConn chNew <-chan ssh.NewChannel chReq <-chan *ssh.Request ) - if sc, chNew, chReq, err = ssh.NewServerConn(conn, cfg); err != nil { + if sc, chNew, chReq, err = ssh.NewServerConn(conn, s.createServerConfig()); err != nil { return }