diff --git a/internal/app/server/server.go b/internal/app/server/server.go index 75b2f14..b3c693b 100644 --- a/internal/app/server/server.go +++ b/internal/app/server/server.go @@ -18,6 +18,7 @@ import ( "github.com/codecrafters-io/redis-starter-go/pkg/rdb" "github.com/codecrafters-io/redis-starter-go/pkg/resp" "github.com/codecrafters-io/redis-starter-go/pkg/telemetry/logger" + "github.com/codecrafters-io/redis-starter-go/pkg/utils" ) const ( @@ -33,6 +34,7 @@ type Server struct { clients map[*client.Client]struct{} cFactory *command.CommandFactory isMaster bool + replicas map[*client.Client]struct{} masterAddr string messageChan chan client.Message masterClient *client.Client @@ -48,20 +50,21 @@ func NewServer(cfg *config.Config) *Server { parts := strings.Split(replicaOf, " ") masterAddr = fmt.Sprintf("%s:%s", parts[0], parts[1]) } - - return &Server{ + s := &Server{ mu: sync.Mutex{}, cfg: cfg, done: make(chan struct{}), store: store, - cFactory: command.NewCommandFactory(store, cfg), clients: make(map[*client.Client]struct{}), isMaster: isMaster, + replicas: make(map[*client.Client]struct{}), masterAddr: masterAddr, messageChan: make(chan client.Message), masterClient: nil, disconnectChan: make(chan *client.Client), } + s.cFactory = command.NewCommandFactory(store, cfg, s) + return s } func (s *Server) Start(ctx context.Context) error { @@ -127,14 +130,19 @@ func (s *Server) acceptConnection(ctx context.Context) error { } } cl := client.NewClient(conn, s.messageChan) - s.mu.Lock() + cl.ID = utils.GenerateRandomAlphanumeric(40) s.clients[cl] = struct{}{} - s.mu.Unlock() go cl.HandleConnection(ctx) } } -func (s *Server) startReplication(ctx context.Context) error { +func (s *Server) startReplication(ctx context.Context) (err error) { + defer func() { + if err != nil { + s.masterClient = nil + } + }() + if s.isMaster { return nil } @@ -147,37 +155,70 @@ func (s *Server) startReplication(ctx context.Context) error { if err := s.sendPingToMaster(ctx); err != nil { return fmt.Errorf("failed to send PING to master: %v", err) } - // TODO: Implement REPLCONF and PSYNC in later stages + + if err := s.sendReplconfToMaster(ctx); err != nil { + return fmt.Errorf("failed to send REPLCONF to master: %v", err) + } return nil } func (s *Server) sendPingToMaster(ctx context.Context) error { pingCmd := resp.CreatePingCommand() - if _, err := s.masterClient.Writer.Write(pingCmd); err != nil { - return err + response, err := s.sendAndReceive(pingCmd) + if err != nil { + return fmt.Errorf("failed to read response from master: %v", err) + } + if response.Type != resp.SimpleString || response.String() != "PONG" { + return fmt.Errorf("unexpected response from master: %v", response) + } + + logger.Info(ctx, "Successfully sent PING to master and received PONG") + return nil +} + +func (s *Server) sendAndReceive(cmd []byte) (*resp.Resp, error) { + if _, err := s.masterClient.Writer.Write(cmd); err != nil { + return nil, err } if err := s.masterClient.Writer.Flush(); err != nil { - return err + return nil, err } conn := s.masterClient.Conn() - conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + return nil, fmt.Errorf("failed to set read deadline: %v", err) + } buffer := make([]byte, 1024) n, err := conn.Read(buffer) if err != nil { - return err + return nil, err } reader := resp.NewResp(bytes.NewReader(buffer[:n])) response, err := reader.ParseResp() if err != nil { - return err + return nil, err } + return response, nil +} - if response.Type != resp.SimpleString || response.String() != "PONG" { - return fmt.Errorf("unexpected response from master: %v", response) +func (s *Server) sendReplconfToMaster(ctx context.Context) error { + // Send REPLCONF listening-port + port, _ := s.cfg.Get(config.ListenAddrKey) + port = strings.TrimPrefix(port, ":") + replconfPort := resp.CreateReplconfCommand("listening-port", port) + response, err := s.sendAndReceive(replconfPort) + if err != nil || response.Type != resp.SimpleString || response.String() != "OK" { + return fmt.Errorf("unexpected response from master for REPLCONF listening-port: %v", response) } - logger.Info(ctx, "Successfully sent PING to master and received PONG") + replconfCapa := resp.CreateReplconfCommand("capa", "psync2") + response, err = s.sendAndReceive(replconfCapa) + if err != nil || response.Type != resp.SimpleString || response.String() != "OK" { + return fmt.Errorf("unexpected response from master for REPLCONF capa psync2: %v", response) + } + + logger.Info(ctx, "Successfully sent REPLCONF commands to master") return nil } @@ -216,7 +257,7 @@ func (s *Server) closeClient(ctx context.Context, cl *client.Client) { func (s *Server) handleMessage(ctx context.Context, cl *client.Client, r *resp.Resp) error { writer := cl.Writer - cmdName := r.Data.([]*resp.Resp)[0].String() + cmdName := strings.ToLower(r.Data.([]*resp.Resp)[0].String()) args := r.Data.([]*resp.Resp)[1:] logger.Info(ctx, "received command, cmd: %s, args: %v", cmdName, args) cmd, err := s.cFactory.GetCommand(cmdName) @@ -233,10 +274,35 @@ func (s *Server) handleMessage(ctx context.Context, cl *client.Client, r *resp.R cl.Writer.Reset() return s.writeError(cl, err) } + if cmdName == "replconf" && len(args) > 0 && args[0].String() == "listening-port" { + s.addReplica(cl) + } return nil } +func (s *Server) addReplica(c *client.Client) { + s.mu.Lock() + defer s.mu.Unlock() + s.replicas[c] = struct{}{} +} + +func (s *Server) GetReplicaInfo() []map[string]string { + s.mu.Lock() + defer s.mu.Unlock() + + info := make([]map[string]string, 0, len(s.replicas)) + for replica := range s.replicas { + replicaInfo := map[string]string{ + "id": replica.ID, + "addr": replica.Conn().RemoteAddr().String(), + "listening_port": replica.ListeningPort, + } + info = append(info, replicaInfo) + } + return info +} + func (s *Server) handleBlockingCommand(ctx context.Context, cl *client.Client, cmd command.Command, writer *resp.Writer, args []*resp.Resp) { err := cmd.Execute(cl, writer, args) if err != nil { diff --git a/internal/client/client.go b/internal/client/client.go index ece4a88..21ac955 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -25,10 +25,12 @@ type Info struct { } type Client struct { + ID string conn net.Conn authenticated bool info Info preferredRespVersion int + ListeningPort string bw *bufio.Writer lastInteraction time.Time disconnectChan chan *Client diff --git a/pkg/command/command.go b/pkg/command/command.go index 0b2797c..ab610e1 100644 --- a/pkg/command/command.go +++ b/pkg/command/command.go @@ -3,7 +3,6 @@ package command import ( "errors" "fmt" - "strings" "github.com/codecrafters-io/redis-starter-go/internal/app/server/config" "github.com/codecrafters-io/redis-starter-go/internal/client" @@ -15,6 +14,10 @@ var ( ErrCommandNotFound = errors.New("unknown command") ) +type ServerInfoProvider interface { + GetReplicaInfo() []map[string]string +} + type Command interface { Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) error IsBlocking(args []*resp.Resp) bool @@ -22,6 +25,7 @@ type Command interface { type CommandFactory struct { commands map[string]Command + // serverInfo ServerInfoProvider } type EchoCommand struct { @@ -53,28 +57,29 @@ func (pc *PingCommand) IsBlocking(_ []*resp.Resp) bool { return false } -func NewCommandFactory(kv keyval.KV, cfg *config.Config) *CommandFactory { +func NewCommandFactory(kv keyval.KV, cfg *config.Config, serverInfo ServerInfoProvider) *CommandFactory { return &CommandFactory{ commands: map[string]Command{ - "echo": &EchoCommand{}, - "ping": &PingCommand{}, - "set": &Set{kv: kv}, - "get": &Get{kv: kv}, - "hello": &Hello{}, - "info": &Info{cfg: cfg}, - "client": &ClientCmd{}, - "config": &ConfigCmd{cfg: cfg}, - "keys": &Keys{kv: kv}, - "type": &TypeCmd{kv: kv}, - "xadd": &XAdd{kv: kv}, - "xrange": &XRange{kv: kv}, - "xread": &XRead{kv: kv}, + "echo": &EchoCommand{}, + "ping": &PingCommand{}, + "set": &Set{kv: kv}, + "get": &Get{kv: kv}, + "hello": &Hello{}, + "info": &Info{cfg: cfg, serverInfo: serverInfo}, + "client": &ClientCmd{}, + "config": &ConfigCmd{cfg: cfg}, + "keys": &Keys{kv: kv}, + "type": &TypeCmd{kv: kv}, + "xadd": &XAdd{kv: kv}, + "xrange": &XRange{kv: kv}, + "xread": &XRead{kv: kv}, + "replconf": &ReplConf{}, }, } } func (cf *CommandFactory) GetCommand(cmd string) (Command, error) { - command, found := cf.commands[strings.ToLower(cmd)] + command, found := cf.commands[cmd] if !found { return nil, fmt.Errorf("%w: %s", ErrCommandNotFound, cmd) } diff --git a/pkg/command/info.go b/pkg/command/info.go index 510c190..3716fb6 100644 --- a/pkg/command/info.go +++ b/pkg/command/info.go @@ -15,11 +15,12 @@ const ( REPLICATION = "replication" ) -type DynamicFieldHandler func(*config.Config) string +type DynamicFieldHandler func(*config.Config, ServerInfoProvider) string type SectionInfo struct { StaticFields map[string]string DynamicFields map[string]DynamicFieldHandler + CustomBuilder func(*config.Config, ServerInfoProvider) string } var sections = map[string]SectionInfo{ @@ -37,12 +38,17 @@ var sections = map[string]SectionInfo{ "role": determineRole, "master_replid": generateMasterReplID, "master_repl_offset": getMasterReplOffset, + "connected_slaves": func(cfg *config.Config, serverInfo ServerInfoProvider) string { + return fmt.Sprintf("%d", len(serverInfo.GetReplicaInfo())) + }, }, + CustomBuilder: buildReplicaInfo, }, } type Info struct { - cfg *config.Config + cfg *config.Config + serverInfo ServerInfoProvider } func (h *Info) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) error { @@ -56,11 +62,11 @@ func (h *Info) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) err if argsLen == 1 { sectionName := args[0].String() if sectionInfo, exists := sections[sectionName]; exists { - str.WriteString(buildSectionString(sectionName, sectionInfo, h.cfg)) + str.WriteString(h.buildSectionString(sectionName, sectionInfo, h.cfg)) } } else { for sectionName, sectionInfo := range sections { - str.WriteString(buildSectionString(sectionName, sectionInfo, h.cfg)) + str.WriteString(h.buildSectionString(sectionName, sectionInfo, h.cfg)) str.WriteString("\r\n") } } @@ -68,7 +74,7 @@ func (h *Info) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) err return wr.WriteValue(str.String()) } -func buildSectionString(sectionName string, sectionInfo SectionInfo, cfg *config.Config) string { +func (h *Info) buildSectionString(sectionName string, sectionInfo SectionInfo, cfg *config.Config) string { // # Server\r\nupstash_version:1.10.5\r\n,etc. var sb strings.Builder sb.WriteString(fmt.Sprintf("# %s\r\n", sectionName)) @@ -78,10 +84,14 @@ func buildSectionString(sectionName string, sectionInfo SectionInfo, cfg *config } for key, handler := range sectionInfo.DynamicFields { - value := handler(cfg) + value := handler(cfg, h.serverInfo) sb.WriteString(fmt.Sprintf("%s:%s\r\n", key, value)) } + if sectionInfo.CustomBuilder != nil { + sb.WriteString(sectionInfo.CustomBuilder(cfg, h.serverInfo)) + } + return sb.String() } @@ -89,7 +99,7 @@ func (h *Info) IsBlocking(_ []*resp.Resp) bool { return false } -func determineRole(cfg *config.Config) string { +func determineRole(cfg *config.Config, _ ServerInfoProvider) string { _, err := cfg.Get(config.ReplicaOfKey) if err == nil { return "slave" @@ -97,10 +107,21 @@ func determineRole(cfg *config.Config) string { return "master" } -func generateMasterReplID(_ *config.Config) string { +func generateMasterReplID(_ *config.Config, _ ServerInfoProvider) string { return utils.GenerateRandomAlphanumeric(40) } -func getMasterReplOffset(_ *config.Config) string { +func getMasterReplOffset(_ *config.Config, _ ServerInfoProvider) string { return "0" } + +func buildReplicaInfo(cfg *config.Config, serverInfo ServerInfoProvider) string { + var sb strings.Builder + replicaInfo := serverInfo.GetReplicaInfo() + + for i, replica := range replicaInfo { + sb.WriteString(fmt.Sprintf("slave%d:id=%s,ip=%s,port=%s,state=online,offset=0,lag=0\r\n", + i, replica["id"], replica["addr"], replica["listening_port"])) + } + return sb.String() +} diff --git a/pkg/command/replconf.go b/pkg/command/replconf.go new file mode 100644 index 0000000..d60f24a --- /dev/null +++ b/pkg/command/replconf.go @@ -0,0 +1,42 @@ +package command + +import ( + "errors" + "fmt" + "strings" + + "github.com/codecrafters-io/redis-starter-go/internal/client" + "github.com/codecrafters-io/redis-starter-go/pkg/resp" +) + +type ReplConf struct{} + +func (rc *ReplConf) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) error { + if len(args) < 2 { + return wr.WriteError(errors.New("wrong number of arguments for 'replconf' command")) + } + + subCommand := strings.ToLower(args[0].String()) + switch subCommand { + case "listening-port": + if len(args) != 2 { + return wr.WriteError(errors.New("wrong number of arguments for 'replconf listening-port' command")) + } + port := args[1].String() + c.ListeningPort = port + case "capa": + if len(args) != 2 { + return wr.WriteError(errors.New("wrong number of arguments for 'replconf capa' command")) + } + // capability := args[1].String() + // Ignore capability for now + default: + return wr.WriteError(fmt.Errorf("unknown replconf subcommand: %s", subCommand)) + } + + return wr.WriteSimpleValue(resp.SimpleString, []byte("OK")) +} + +func (rc *ReplConf) IsBlocking(_ []*resp.Resp) bool { + return false +} diff --git a/pkg/command/xread.go b/pkg/command/xread.go index c0cde4d..9fbbf27 100644 --- a/pkg/command/xread.go +++ b/pkg/command/xread.go @@ -138,6 +138,7 @@ func (x *XRead) readStreams(opts *XReadOptions) (map[string][]keyval.StreamEntry func (x *XRead) blockingRead(opts *XReadOptions) (map[string][]keyval.StreamEntry, error) { result := make(map[string][]keyval.StreamEntry) subscriptions := make(map[string]chan keyval.StreamEntry) + defer func() { for streamName, ch := range subscriptions { stream, _ := x.kv.GetStream(streamName, false) @@ -162,9 +163,9 @@ func (x *XRead) blockingRead(opts *XReadOptions) (map[string][]keyval.StreamEntr timer = time.NewTimer(*opts.Block) defer timer.Stop() } - cases := make([]reflect.SelectCase, 0, len(subscriptions)+1) - for _, ch := range subscriptions { + for _, streamName := range x.streamOrder { + ch := subscriptions[streamName] cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}) } diff --git a/pkg/resp/command.go b/pkg/resp/command.go index 4eb0835..1363464 100644 --- a/pkg/resp/command.go +++ b/pkg/resp/command.go @@ -5,7 +5,7 @@ import "bytes" func CreateCommand(args ...string) []byte { var buf bytes.Buffer w := NewWriter(&buf, RESP3) - w.WriteStringSlice(args) + _ = w.WriteStringSlice(args) return buf.Bytes() }