From 16ab8bd3cb7db38c39b611edd980585737d36a88 Mon Sep 17 00:00:00 2001 From: Prateek Singh Rathore Date: Wed, 4 Dec 2024 16:40:55 +0530 Subject: [PATCH 1/8] refactored to separate commandHandler from ioThread --- config/config.go | 1 + integration_tests/commands/resp/setup.go | 7 +- .../cmd_compose.go | 4 +- .../cmd_custom.go | 6 +- .../cmd_decompose.go | 40 +- .../{iothread => commandhandler}/cmd_meta.go | 16 +- .../cmd_preprocess.go | 14 +- internal/commandhandler/commandhandler.go | 529 +++++++++++++++++ internal/commandhandler/manager.go | 70 +++ internal/errors/errors.go | 2 +- internal/iothread/iothread.go | 548 +----------------- internal/iothread/manager.go | 18 +- internal/ops/store_op.go | 4 +- internal/server/httpServer.go | 25 +- internal/server/resp/server.go | 60 +- internal/server/websocketServer.go | 19 +- internal/shard/shard_manager.go | 11 +- internal/shard/shard_thread.go | 48 +- main.go | 7 +- 19 files changed, 782 insertions(+), 647 deletions(-) rename internal/{iothread => commandhandler}/cmd_compose.go (98%) rename internal/{iothread => commandhandler}/cmd_custom.go (93%) rename internal/{iothread => commandhandler}/cmd_decompose.go (78%) rename internal/{iothread => commandhandler}/cmd_meta.go (96%) rename internal/{iothread => commandhandler}/cmd_preprocess.go (81%) create mode 100644 internal/commandhandler/commandhandler.go create mode 100644 internal/commandhandler/manager.go diff --git a/config/config.go b/config/config.go index 98d1fc64f..0c4d9f58c 100644 --- a/config/config.go +++ b/config/config.go @@ -136,6 +136,7 @@ type performance struct { ShardCronFrequency time.Duration `config:"shard_cron_frequency" default:"1s"` MultiplexerPollTimeout time.Duration `config:"multiplexer_poll_timeout" default:"100ms"` MaxClients int32 `config:"max_clients" default:"20000" validate:"min=0"` + MaxCmdHandlers int32 `config:"max_cmd_handlers" default:"20000" validate:"min=0"` StoreMapInitSize int `config:"store_map_init_size" default:"1024000"` AdhocReqChanBufSize int `config:"adhoc_req_chan_buf_size" default:"20"` EnableProfiling bool `config:"profiling" default:"false"` diff --git a/integration_tests/commands/resp/setup.go b/integration_tests/commands/resp/setup.go index b98cdd31c..eafdc554f 100644 --- a/integration_tests/commands/resp/setup.go +++ b/integration_tests/commands/resp/setup.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/dicedb/dice/internal/commandhandler" "github.com/dicedb/dice/internal/iothread" "github.com/dicedb/dice/internal/server/resp" "github.com/dicedb/dice/internal/wal" @@ -196,10 +197,12 @@ func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) { cmdWatchSubscriptionChan := make(chan watchmanager.WatchSubscription) gec := make(chan error) shardManager := shard.NewShardManager(1, queryWatchChan, cmdWatchChan, gec) - ioThreadManager := iothread.NewManager(20000, shardManager) + ioThreadManager := iothread.NewManager(20000) + cmdHandlerManager := commandhandler.NewManager(20000, shardManager) + // Initialize the RESP Server wl, _ := wal.NewNullWAL() - testServer := resp.NewServer(shardManager, ioThreadManager, cmdWatchSubscriptionChan, cmdWatchChan, gec, wl) + testServer := resp.NewServer(shardManager, ioThreadManager, cmdHandlerManager, cmdWatchSubscriptionChan, cmdWatchChan, gec, wl) ctx, cancel := context.WithCancel(context.Background()) fmt.Println("Starting the test server on port", config.DiceConfig.RespServer.Port) diff --git a/internal/iothread/cmd_compose.go b/internal/commandhandler/cmd_compose.go similarity index 98% rename from internal/iothread/cmd_compose.go rename to internal/commandhandler/cmd_compose.go index 475269452..50c5753c2 100644 --- a/internal/iothread/cmd_compose.go +++ b/internal/commandhandler/cmd_compose.go @@ -1,4 +1,4 @@ -package iothread +package commandhandler import ( "math" @@ -8,7 +8,7 @@ import ( "github.com/dicedb/dice/internal/ops" ) -// This file contains functions used by the IOThread to handle and process responses +// This file contains functions used by the CommandHandler to handle and process responses // from multiple shards during distributed operations. For commands that are executed // across several shards, such as MultiShard commands, dedicated functions are responsible // for aggregating and managing the results. diff --git a/internal/iothread/cmd_custom.go b/internal/commandhandler/cmd_custom.go similarity index 93% rename from internal/iothread/cmd_custom.go rename to internal/commandhandler/cmd_custom.go index 39dc32304..3a0952306 100644 --- a/internal/iothread/cmd_custom.go +++ b/internal/commandhandler/cmd_custom.go @@ -1,4 +1,4 @@ -package iothread +package commandhandler import ( "fmt" @@ -12,7 +12,7 @@ import ( // RespAuth returns with an encoded "OK" if the user is authenticated // If the user is not authenticated, it returns with an encoded error message -func (t *BaseIOThread) RespAuth(args []string) interface{} { +func (h *BaseCommandHandler) RespAuth(args []string) interface{} { // Check for incorrect number of arguments (arity error). if len(args) < 1 || len(args) > 2 { return diceerrors.ErrWrongArgumentCount("AUTH") @@ -31,7 +31,7 @@ func (t *BaseIOThread) RespAuth(args []string) interface{} { username, password = args[0], args[1] } - if err := t.Session.Validate(username, password); err != nil { + if err := h.Session.Validate(username, password); err != nil { return err } diff --git a/internal/iothread/cmd_decompose.go b/internal/commandhandler/cmd_decompose.go similarity index 78% rename from internal/iothread/cmd_decompose.go rename to internal/commandhandler/cmd_decompose.go index d09196d88..fb40d0a16 100644 --- a/internal/iothread/cmd_decompose.go +++ b/internal/commandhandler/cmd_decompose.go @@ -1,4 +1,4 @@ -package iothread +package commandhandler import ( "context" @@ -12,7 +12,7 @@ import ( "github.com/dicedb/dice/internal/store" ) -// This file is utilized by the IOThread to decompose commands that need to be executed +// This file is utilized by the CommandHandler to decompose commands that need to be executed // across multiple shards. For commands that operate on multiple keys or necessitate // distribution across shards (e.g., MultiShard commands), a Breakup function is invoked // to transform the original command into multiple smaller commands, each directed at @@ -25,13 +25,13 @@ import ( // decomposeRename breaks down the RENAME command into separate DELETE and SET commands. // It first waits for the result of a GET command from shards. If successful, it removes // the old key using a DEL command and sets the new key with the retrieved value using a SET command. -func decomposeRename(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeRename(ctx context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { // Waiting for GET command response var val string select { case <-ctx.Done(): - slog.Error("IOThread timed out waiting for response from shards", slog.String("id", thread.id), slog.Any("error", ctx.Err())) - case preProcessedResp, ok := <-thread.preprocessingChan: + slog.Error("CommandHandler timed out waiting for response from shards", slog.String("id", h.id), slog.Any("error", ctx.Err())) + case preProcessedResp, ok := <-h.preprocessingChan: if ok { evalResp := preProcessedResp.EvalResponse if evalResp.Error != nil { @@ -69,13 +69,13 @@ func decomposeRename(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCm // decomposeCopy breaks down the COPY command into a SET command that copies a value from // one key to another. It first retrieves the value of the original key from shards, then // sets the value to the destination key using a SET command. -func decomposeCopy(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeCopy(ctx context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { // Waiting for GET command response var resp *ops.StoreResponse select { case <-ctx.Done(): - slog.Error("IOThread timed out waiting for response from shards", slog.String("id", thread.id), slog.Any("error", ctx.Err())) - case preProcessedResp, ok := <-thread.preprocessingChan: + slog.Error("CommandHandler timed out waiting for response from shards", slog.String("id", h.id), slog.Any("error", ctx.Err())) + case preProcessedResp, ok := <-h.preprocessingChan: if ok { resp = preProcessedResp } @@ -108,7 +108,7 @@ func decomposeCopy(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) // decomposeMSet decomposes the MSET (Multi-set) command into individual SET commands. // It expects an even number of arguments (key-value pairs). For each pair, it creates // a separate SET command to store the value at the given key. -func decomposeMSet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeMSet(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args)%2 != 0 { return nil, diceerrors.ErrWrongArgumentCount("MSET") } @@ -132,7 +132,7 @@ func decomposeMSet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cm // decomposeMGet decomposes the MGET (Multi-get) command into individual GET commands. // It expects a list of keys, and for each key, it creates a separate GET command to // retrieve the value associated with that key. -func decomposeMGet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeMGet(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) < 1 { return nil, diceerrors.ErrWrongArgumentCount("MGET") } @@ -148,7 +148,7 @@ func decomposeMGet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cm return decomposedCmds, nil } -func decomposeSInter(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeSInter(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) < 1 { return nil, diceerrors.ErrWrongArgumentCount("SINTER") } @@ -164,7 +164,7 @@ func decomposeSInter(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]* return decomposedCmds, nil } -func decomposeSDiff(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeSDiff(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) < 1 { return nil, diceerrors.ErrWrongArgumentCount("SDIFF") } @@ -180,7 +180,7 @@ func decomposeSDiff(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*c return decomposedCmds, nil } -func decomposeJSONMget(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeJSONMget(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) < 2 { return nil, diceerrors.ErrWrongArgumentCount("JSON.MGET") } @@ -199,7 +199,7 @@ func decomposeJSONMget(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([ return decomposedCmds, nil } -func decomposeTouch(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeTouch(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) == 0 { return nil, diceerrors.ErrWrongArgumentCount("TOUCH") } @@ -216,13 +216,13 @@ func decomposeTouch(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*c return decomposedCmds, nil } -func decomposeDBSize(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeDBSize(_ context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) > 0 { return nil, diceerrors.ErrWrongArgumentCount("DBSIZE") } decomposedCmds := make([]*cmd.DiceDBCmd, 0, len(cd.Args)) - for i := uint8(0); i < uint8(thread.shardManager.GetShardCount()); i++ { + for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ { decomposedCmds = append(decomposedCmds, &cmd.DiceDBCmd{ Cmd: store.SingleShardSize, @@ -233,13 +233,13 @@ func decomposeDBSize(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) return decomposedCmds, nil } -func decomposeKeys(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeKeys(_ context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) != 1 { return nil, diceerrors.ErrWrongArgumentCount("KEYS") } decomposedCmds := make([]*cmd.DiceDBCmd, 0, len(cd.Args)) - for i := uint8(0); i < uint8(thread.shardManager.GetShardCount()); i++ { + for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ { decomposedCmds = append(decomposedCmds, &cmd.DiceDBCmd{ Cmd: store.SingleShardKeys, @@ -250,13 +250,13 @@ func decomposeKeys(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ( return decomposedCmds, nil } -func decomposeFlushDB(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeFlushDB(_ context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) > 1 { return nil, diceerrors.ErrWrongArgumentCount("FLUSHDB") } decomposedCmds := make([]*cmd.DiceDBCmd, 0, len(cd.Args)) - for i := uint8(0); i < uint8(thread.shardManager.GetShardCount()); i++ { + for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ { decomposedCmds = append(decomposedCmds, &cmd.DiceDBCmd{ Cmd: store.FlushDB, diff --git a/internal/iothread/cmd_meta.go b/internal/commandhandler/cmd_meta.go similarity index 96% rename from internal/iothread/cmd_meta.go rename to internal/commandhandler/cmd_meta.go index bda33d29f..a97de90bd 100644 --- a/internal/iothread/cmd_meta.go +++ b/internal/commandhandler/cmd_meta.go @@ -1,4 +1,4 @@ -package iothread +package commandhandler import ( "context" @@ -191,12 +191,12 @@ const ( type CmdMeta struct { CmdType - Cmd string - IOThreadHandler func([]string) []byte + Cmd string + CmdHandlerFunction func([]string) []byte // decomposeCommand is a function that takes a DiceDB command and breaks it down into smaller, // manageable DiceDB commands for each shard processing. It returns a slice of DiceDB commands. - decomposeCommand func(ctx context.Context, thread *BaseIOThread, DiceDBCmd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) + decomposeCommand func(ctx context.Context, h *BaseCommandHandler, DiceDBCmd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) // composeResponse is a function that combines multiple responses from the execution of commands // into a single response object. It accepts a variadic parameter of EvalResponse objects @@ -211,10 +211,10 @@ type CmdMeta struct { // preProcessResponse is a function that handles the preprocessing of a DiceDB command by // preparing the necessary operations (e.g., fetching values from shards) before the command - // is executed. It takes the io-thread and the original DiceDB command as parameters and + // is executed. It takes the CommandHandler and the original DiceDB command as parameters and // ensures that any required information is retrieved and processed in advance. Use this when set // preProcessingReq = true. - preProcessResponse func(thread *BaseIOThread, DiceDBCmd *cmd.DiceDBCmd) error + preProcessResponse func(h *BaseCommandHandler, DiceDBCmd *cmd.DiceDBCmd) error } var CommandsMeta = map[string]CmdMeta{ @@ -651,8 +651,8 @@ func init() { func validateCmdMeta(c string, meta CmdMeta) error { switch meta.CmdType { case Global: - if meta.IOThreadHandler == nil { - return fmt.Errorf("global command %s must have IOThreadHandler function", c) + if meta.CmdHandlerFunction == nil { + return fmt.Errorf("global command %s must have CmdHandlerFunction function", c) } case MultiShard, AllShard: if meta.decomposeCommand == nil || meta.composeResponse == nil { diff --git a/internal/iothread/cmd_preprocess.go b/internal/commandhandler/cmd_preprocess.go similarity index 81% rename from internal/iothread/cmd_preprocess.go rename to internal/commandhandler/cmd_preprocess.go index 0bacdc8b5..3287eb937 100644 --- a/internal/iothread/cmd_preprocess.go +++ b/internal/commandhandler/cmd_preprocess.go @@ -1,4 +1,4 @@ -package iothread +package commandhandler import ( "github.com/dicedb/dice/internal/cmd" @@ -9,13 +9,13 @@ import ( // preProcessRename prepares the RENAME command for preprocessing by sending a GET command // to retrieve the value of the original key. The retrieved value is used later in the // decomposeRename function to delete the old key and set the new key. -func preProcessRename(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { +func preProcessRename(h *BaseCommandHandler, diceDBCmd *cmd.DiceDBCmd) error { if len(diceDBCmd.Args) < 2 { return diceerrors.ErrWrongArgumentCount("RENAME") } key := diceDBCmd.Args[0] - sid, rc := thread.shardManager.GetShardInfo(key) + sid, rc := h.shardManager.GetShardInfo(key) preCmd := cmd.DiceDBCmd{ Cmd: "RENAME", @@ -26,7 +26,7 @@ func preProcessRename(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { SeqID: 0, RequestID: GenerateUniqueRequestID(), Cmd: &preCmd, - IOThreadID: thread.id, + CmdHandlerID: h.id, ShardID: sid, Client: nil, PreProcessing: true, @@ -38,12 +38,12 @@ func preProcessRename(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { // preProcessCopy prepares the COPY command for preprocessing by sending a GET command // to retrieve the value of the original key. The retrieved value is used later in the // decomposeCopy function to copy the value to the destination key. -func customProcessCopy(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { +func customProcessCopy(h *BaseCommandHandler, diceDBCmd *cmd.DiceDBCmd) error { if len(diceDBCmd.Args) < 2 { return diceerrors.ErrWrongArgumentCount("COPY") } - sid, rc := thread.shardManager.GetShardInfo(diceDBCmd.Args[0]) + sid, rc := h.shardManager.GetShardInfo(diceDBCmd.Args[0]) preCmd := cmd.DiceDBCmd{ Cmd: "COPY", @@ -55,7 +55,7 @@ func customProcessCopy(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { SeqID: 0, RequestID: GenerateUniqueRequestID(), Cmd: &preCmd, - IOThreadID: thread.id, + CmdHandlerID: h.id, ShardID: sid, Client: nil, PreProcessing: true, diff --git a/internal/commandhandler/commandhandler.go b/internal/commandhandler/commandhandler.go new file mode 100644 index 000000000..6313cac2e --- /dev/null +++ b/internal/commandhandler/commandhandler.go @@ -0,0 +1,529 @@ +package commandhandler + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "strconv" + "sync/atomic" + "syscall" + "time" + + "github.com/dicedb/dice/config" + "github.com/dicedb/dice/internal/auth" + "github.com/dicedb/dice/internal/clientio" + "github.com/dicedb/dice/internal/clientio/requestparser" + "github.com/dicedb/dice/internal/cmd" + diceerrors "github.com/dicedb/dice/internal/errors" + "github.com/dicedb/dice/internal/ops" + "github.com/dicedb/dice/internal/querymanager" + "github.com/dicedb/dice/internal/shard" + "github.com/dicedb/dice/internal/wal" + "github.com/dicedb/dice/internal/watchmanager" + "github.com/google/uuid" +) + +const defaultRequestTimeout = 6 * time.Second + +var requestCounter uint32 + +type CommandHandler interface { + ID() string + Start(ctx context.Context) error + Stop() error +} + +type BaseCommandHandler struct { + CommandHandler + id string + parser requestparser.Parser + shardManager *shard.ShardManager + adhocReqChan chan *cmd.DiceDBCmd + Session *auth.Session + globalErrorChan chan error + ioThreadReadChan chan []byte // Channel to receive data from io-thread + ioThreadWriteChan chan interface{} // Channel to send data to io-thread + responseChan chan *ops.StoreResponse // Channel to communicate with shard + preprocessingChan chan *ops.StoreResponse // Channel to communicate with shard + cmdWatchSubscriptionChan chan watchmanager.WatchSubscription + wl wal.AbstractWAL +} + +func NewCommandHandler(id string, responseChan, preprocessingChan chan *ops.StoreResponse, + cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, + parser requestparser.Parser, shardManager *shard.ShardManager, gec chan error, + ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}, + wl wal.AbstractWAL) *BaseCommandHandler { + return &BaseCommandHandler{ + id: id, + parser: parser, + shardManager: shardManager, + adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize), + Session: auth.NewSession(), + globalErrorChan: gec, + ioThreadReadChan: ioThreadReadChan, + ioThreadWriteChan: ioThreadWriteChan, + responseChan: responseChan, + preprocessingChan: preprocessingChan, + cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, + wl: wl, + } +} + +func (h *BaseCommandHandler) ID() string { + return h.id +} + +func (h *BaseCommandHandler) Start(ctx context.Context) error { + errChan := make(chan error, 1) // for adhoc request processing errors + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case cmdReq := <-h.adhocReqChan: + h.handleCmdRequestWithTimeout(ctx, errChan, []*cmd.DiceDBCmd{cmdReq}, true, defaultRequestTimeout) + case err := <-errChan: + return h.handleError(err) + case data := <-h.ioThreadReadChan: + if err := h.processCommand(ctx, &data, h.globalErrorChan); err != nil { + return err + } + } + } +} + +// processCommand processes commands recevied from io thread +func (h *BaseCommandHandler) processCommand(ctx context.Context, data *[]byte, errChan chan error) error { + commands, err := h.parser.Parse(*data) + + if err != nil { + slog.Debug("error parsing commands from io thread", slog.String("id", h.id), slog.Any("error", err)) + h.ioThreadWriteChan <- err + return nil + } + + if len(commands) == 0 { + slog.Debug("invalid request from io thread with zero length", slog.String("id", h.id)) + h.ioThreadWriteChan <- fmt.Errorf("ERR: Invalid request") + return nil + } + + // DiceDB supports clients to send only one request at a time + // We also need to ensure that the client is blocked until the response is received + if len(commands) > 1 { + h.ioThreadWriteChan <- fmt.Errorf("ERR: Multiple commands not supported") + return nil + } + + err = h.isAuthenticated(commands[0]) + if err != nil { + slog.Debug("command handler authentication failed", slog.String("id", h.id), slog.Any("error", err)) + h.ioThreadWriteChan <- err + return nil + } + + h.handleCmdRequestWithTimeout(ctx, errChan, commands, false, defaultRequestTimeout) + return nil +} + +func (h *BaseCommandHandler) handleCmdRequestWithTimeout(ctx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) { + execCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + h.executeCommandHandler(execCtx, errChan, commands, isWatchNotification) +} + +func (h *BaseCommandHandler) handleError(err error) error { + if err != nil { + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { + slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err)) + return err + } + } + return fmt.Errorf("error writing response: %v", err) +} + +func (h *BaseCommandHandler) executeCommandHandler(execCtx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) { + // Retrieve metadata for the command to determine if multisharding is supported. + meta, ok := CommandsMeta[commands[0].Cmd] + if ok && meta.preProcessing { + if err := meta.preProcessResponse(h, commands[0]); err != nil { + slog.Debug("error pre processing response", slog.String("id", h.id), slog.Any("error", err)) + h.ioThreadWriteChan <- err + } + } + + err := h.executeCommand(execCtx, commands[0], isWatchNotification) + if err != nil { + slog.Error("Error executing command", slog.String("id", h.id), slog.Any("error", err)) + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { + slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err)) + errChan <- err + } + h.ioThreadWriteChan <- err + } +} + +func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) error { + // Break down the single command into multiple commands if multisharding is supported. + // The length of cmdList helps determine how many shards to wait for responses. + cmdList := make([]*cmd.DiceDBCmd, 0) + var watchLabel string + + // Retrieve metadata for the command to determine if multisharding is supported. + meta, ok := CommandsMeta[diceDBCmd.Cmd] + if !ok { + // If no metadata exists, treat it as a single command and not migrated + cmdList = append(cmdList, diceDBCmd) + } else { + // Depending on the command type, decide how to handle it. + switch meta.CmdType { + case Global: + // If it's a global command, process it immediately without involving any shards. + h.ioThreadWriteChan <- meta.CmdHandlerFunction(diceDBCmd.Args) + return nil + + case SingleShard: + // For single-shard or custom commands, process them without breaking up. + cmdList = append(cmdList, diceDBCmd) + + case MultiShard, AllShard: + var err error + // If the command supports multisharding, break it down into multiple commands. + cmdList, err = meta.decomposeCommand(ctx, h, diceDBCmd) + if err != nil { + slog.Debug("error decomposing command", slog.String("id", h.id), slog.Any("error", err)) + // Check if it's a CustomError + var customErr *diceerrors.PreProcessError + if errors.As(err, &customErr) { + h.ioThreadWriteChan <- customErr.Result + } else { + h.ioThreadWriteChan <- err + } + return nil + } + + case Custom: + return h.handleCustomCommands(diceDBCmd) + + case Watch: + // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass + // it along as is. + // Modify the command name to remove the .WATCH suffix, this will allow us to generate a consistent + // fingerprint (which uses the command name without the suffix) + diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-6] + + // check if the last argument is a watch label + label := diceDBCmd.Args[len(diceDBCmd.Args)-1] + if _, err := uuid.Parse(label); err == nil { + watchLabel = label + + // remove the watch label from the args + diceDBCmd.Args = diceDBCmd.Args[:len(diceDBCmd.Args)-1] + } + + watchCmd := &cmd.DiceDBCmd{ + Cmd: diceDBCmd.Cmd, + Args: diceDBCmd.Args, + } + cmdList = append(cmdList, watchCmd) + isWatchNotification = true + + case Unwatch: + // Generate the Cmd being unwatched. All we need to do is remove the .UNWATCH suffix from the command and pass + // it along as is. + // Modify the command name to remove the .UNWATCH suffix, this will allow us to generate a consistent + // fingerprint (which uses the command name without the suffix) + diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-8] + watchCmd := &cmd.DiceDBCmd{ + Cmd: diceDBCmd.Cmd, + Args: diceDBCmd.Args, + } + cmdList = append(cmdList, watchCmd) + isWatchNotification = false + } + } + + // Unsubscribe Unwatch command type + if meta.CmdType == Unwatch { + return h.handleCommandUnwatch(cmdList) + } + + // Scatter the broken-down commands to the appropriate shards. + if err := h.scatter(ctx, cmdList, meta.CmdType); err != nil { + return err + } + + // Gather the responses from the shards and write them to the buffer. + if err := h.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel); err != nil { + return err + } + + if meta.CmdType == Watch { + // Proceed to subscribe after successful execution + h.handleCommandWatch(cmdList) + } + + return nil +} + +func (h *BaseCommandHandler) handleCustomCommands(diceDBCmd *cmd.DiceDBCmd) error { + // if command is of type Custom, write a custom logic around it + switch diceDBCmd.Cmd { + case CmdAuth: + h.ioThreadWriteChan <- h.RespAuth(diceDBCmd.Args) + return nil + case CmdEcho: + h.ioThreadWriteChan <- RespEcho(diceDBCmd.Args) + return nil + case CmdAbort: + h.ioThreadWriteChan <- clientio.OK + slog.Info("Received ABORT command, initiating server shutdown", slog.String("id", h.id)) + h.globalErrorChan <- diceerrors.ErrAborted + return nil + case CmdPing: + h.ioThreadWriteChan <- RespPING(diceDBCmd.Args) + return nil + case CmdHello: + h.ioThreadWriteChan <- RespHello(diceDBCmd.Args) + return nil + case CmdSleep: + h.ioThreadWriteChan <- RespSleep(diceDBCmd.Args) + return nil + default: + return diceerrors.ErrUnknownCmd(diceDBCmd.Cmd) + } +} + +// handleCommandWatch sends a watch subscription request to the watch manager. +func (h *BaseCommandHandler) handleCommandWatch(cmdList []*cmd.DiceDBCmd) { + h.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ + Subscribe: true, + WatchCmd: cmdList[len(cmdList)-1], + AdhocReqChan: h.adhocReqChan, + } +} + +// handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client. +// The response is sent before the unwatch request is processed by the watch manager. +func (h *BaseCommandHandler) handleCommandUnwatch(cmdList []*cmd.DiceDBCmd) error { + // extract the fingerprint + command := cmdList[len(cmdList)-1] + fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32) + if parseErr != nil { + h.ioThreadWriteChan <- diceerrors.ErrInvalidFingerprint + return nil + } + + // send the unsubscribe request + h.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ + Subscribe: false, + AdhocReqChan: h.adhocReqChan, + Fingerprint: uint32(fp), + } + + h.ioThreadWriteChan <- clientio.OK + return nil +} + +// scatter distributes the DiceDB commands to the respective shards based on the key. +// For each command, it calculates the shard ID and sends the command to the shard's request channel for processing. +func (h *BaseCommandHandler) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd, cmdType CmdType) error { + // Otherwise check for the shard based on the key using hash + // and send it to the particular shard + // Check if the context has been canceled or expired. + select { + case <-ctx.Done(): + // If the context is canceled, return the error associated with it. + return ctx.Err() + default: + // Proceed with the default case when the context is not canceled. + + if cmdType == AllShard { + // If the command type is for all shards, iterate over all available shards. + for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ { + // Get the shard ID (i) and its associated request channel. + shardID, responseChan := i, h.shardManager.GetShard(i).ReqChan + + // Send a StoreOp operation to the shard's request channel. + responseChan <- &ops.StoreOp{ + SeqID: i, // Sequence ID for this operation. + RequestID: GenerateUniqueRequestID(), // Unique identifier for the request. + Cmd: cmds[0], // Command to be executed, using the first command in cmds. + CmdHandlerID: h.id, // ID of the current command handler. + ShardID: shardID, // ID of the shard handling this operation. + Client: nil, // Client information (if applicable). + } + } + } else { + // If the command type is specific to certain commands, process them individually. + for i := uint8(0); i < uint8(len(cmds)); i++ { + // Determine the appropriate shard for the current command using a routing key. + shardID, responseChan := h.shardManager.GetShardInfo(getRoutingKeyFromCommand(cmds[i])) + + // Send a StoreOp operation to the shard's request channel. + responseChan <- &ops.StoreOp{ + SeqID: i, // Sequence ID for this operation. + RequestID: GenerateUniqueRequestID(), // Unique identifier for the request. + Cmd: cmds[i], // Command to be executed, using the current command in cmds. + CmdHandlerID: h.id, // ID of the current command handler. + ShardID: shardID, // ID of the shard handling this operation. + Client: nil, // Client information (if applicable). + } + } + } + } + + return nil +} + +// getRoutingKeyFromCommand determines the key used for shard routing +func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string { + if len(diceDBCmd.Args) > 0 { + return diceDBCmd.Args[0] + } + return diceDBCmd.Cmd +} + +// gather collects the responses from multiple shards and writes the results into the provided buffer. +// It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). +func (h *BaseCommandHandler) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) error { + // Collect responses from all shards + storeOp, err := h.gatherResponses(ctx, numCmds) + if err != nil { + return err + } + + if len(storeOp) == 0 { + slog.Error("No response from shards", + slog.String("id", h.id), + slog.String("command", diceDBCmd.Cmd)) + return fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd) + } + + if isWatchNotification { + return h.handleWatchNotification(ctx, diceDBCmd, storeOp[0], watchLabel) + } + + // Process command based on its type + cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd] + if !ok { + return h.handleUnsupportedCommand(ctx, storeOp[0]) + } + + return h.handleCommand(cmdMeta, diceDBCmd, storeOp) +} + +// gatherResponses collects responses from all shards +func (h *BaseCommandHandler) gatherResponses(ctx context.Context, numCmds int) ([]ops.StoreResponse, error) { + storeOp := make([]ops.StoreResponse, 0, numCmds) + + for numCmds > 0 { + select { + case <-ctx.Done(): + slog.Error("Timed out waiting for response from shards", + slog.String("id", h.id), + slog.Any("error", ctx.Err())) + return nil, ctx.Err() + + case resp, ok := <-h.responseChan: + if ok { + storeOp = append(storeOp, *resp) + } + numCmds-- + + case sError, ok := <-h.shardManager.ShardErrorChan: + if ok { + slog.Error("Error from shard", + slog.String("id", h.id), + slog.Any("error", sError)) + return nil, sError.Error + } + } + } + + return storeOp, nil +} + +// handleWatchNotification processes watch notification responses +func (h *BaseCommandHandler) handleWatchNotification(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) error { + fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint()) + + // if watch label is not empty, then this is the first response for the watch command + // hence, we will send the watch label as part of the response + firstRespElem := diceDBCmd.Cmd + if watchLabel != "" { + firstRespElem = watchLabel + } + + if resp.EvalResponse.Error != nil { + return h.writeResponse(querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error)) + } + + return h.writeResponse(querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result)) +} + +// handleUnsupportedCommand processes commands not in CommandsMeta +func (h *BaseCommandHandler) handleUnsupportedCommand(ctx context.Context, resp ops.StoreResponse) error { + if resp.EvalResponse.Error != nil { + return h.writeResponse(resp.EvalResponse.Error) + } + return h.writeResponse(resp.EvalResponse.Result) +} + +// handleCommand processes commands based on their type +func (h *BaseCommandHandler) handleCommand(cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) error { + var err error + + switch cmdMeta.CmdType { + case SingleShard, Custom: + if storeOp[0].EvalResponse.Error != nil { + err = h.writeResponse(storeOp[0].EvalResponse.Error) + } else { + err = h.writeResponse(storeOp[0].EvalResponse.Result) + } + + if err == nil && h.wl != nil { + h.wl.LogCommand(diceDBCmd) + } + case MultiShard, AllShard: + err = h.writeResponse(cmdMeta.composeResponse(storeOp...)) + + if err == nil && h.wl != nil { + h.wl.LogCommand(diceDBCmd) + } + default: + slog.Error("Unknown command type", + slog.String("id", h.id), + slog.String("command", diceDBCmd.Cmd), + slog.Any("evalResp", storeOp)) + err = h.writeResponse(diceerrors.ErrInternalServer) + } + return err +} + +// writeResponse handles writing responses and logging errors +func (h *BaseCommandHandler) writeResponse(response interface{}) error { + h.ioThreadWriteChan <- response + return nil +} + +func (h *BaseCommandHandler) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error { + if diceDBCmd.Cmd != auth.Cmd && !h.Session.IsActive() { + return errors.New("NOAUTH Authentication required") + } + + return nil +} + +func (h *BaseCommandHandler) Stop() error { + slog.Info("Stopping command handler", slog.String("id", h.id)) + h.Session.Expire() + return nil +} + +func GenerateUniqueRequestID() uint32 { + return atomic.AddUint32(&requestCounter, 1) +} diff --git a/internal/commandhandler/manager.go b/internal/commandhandler/manager.go new file mode 100644 index 000000000..a29d9c6a3 --- /dev/null +++ b/internal/commandhandler/manager.go @@ -0,0 +1,70 @@ +package commandhandler + +import ( + "errors" + "sync" + "sync/atomic" + + "github.com/dicedb/dice/internal/shard" +) + +type Manager struct { + activeCmdHandlers sync.Map + numCmdHandlers atomic.Int32 + maxCmdHandlers int32 + ShardManager *shard.ShardManager + mu sync.Mutex +} + +var ( + ErrMaxCmdHandlersReached = errors.New("maximum number of command handlers reached") + ErrCmdHandlerNotFound = errors.New("command handler not found") +) + +func NewManager(maxCmdHandlers int32, sm *shard.ShardManager) *Manager { + return &Manager{ + maxCmdHandlers: maxCmdHandlers, + ShardManager: sm, + } +} + +func (m *Manager) RegisterCommandHandler(cmdHandler *BaseCommandHandler) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.CommandHandlerCount() >= m.maxCmdHandlers { + return ErrMaxCmdHandlersReached + } + + responseChan := cmdHandler.responseChan + preprocessingChan := cmdHandler.preprocessingChan + + if responseChan != nil && preprocessingChan != nil { + m.ShardManager.RegisterCommandHandler(cmdHandler.ID(), responseChan, preprocessingChan) // TODO: Change responseChan type to ShardResponse + } else if responseChan != nil && preprocessingChan == nil { + m.ShardManager.RegisterCommandHandler(cmdHandler.ID(), responseChan, nil) + } + + m.activeCmdHandlers.Store(cmdHandler.ID(), cmdHandler) + + m.numCmdHandlers.Add(1) + return nil +} + +func (m *Manager) CommandHandlerCount() int32 { + return m.numCmdHandlers.Load() +} + +func (m *Manager) UnregisterCommandHandler(id string) error { + m.ShardManager.UnregisterCommandHandler(id) + if client, loaded := m.activeCmdHandlers.LoadAndDelete(id); loaded { + w := client.(BaseCommandHandler) + if err := w.Stop(); err != nil { + return err + } + } else { + return ErrCmdHandlerNotFound + } + m.numCmdHandlers.Add(-1) + return nil +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 306b94d3c..fcdc9b8b7 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -19,7 +19,7 @@ const ( WrongTypeErr = "-WRONGTYPE Operation against a key holding the wrong kind of value" WrongTypeHllErr = "-WRONGTYPE Key is not a valid HyperLogLog string value." InvalidHllErr = "-INVALIDOBJ Corrupted HLL object detected" - IOThreadNotFoundErr = "io-thread with ID %s not found" + CmdHandlerNotFoundErr = "command handler with ID %s not found" JSONPathValueTypeErr = "-WRONGTYPE wrong type of path value - expected string but found %s" HashValueNotIntegerErr = "hash value is not an integer" InternalServerError = "-ERR: Internal server error, unable to process command" diff --git a/internal/iothread/iothread.go b/internal/iothread/iothread.go index 6aab5120f..f95707d10 100644 --- a/internal/iothread/iothread.go +++ b/internal/iothread/iothread.go @@ -2,34 +2,15 @@ package iothread import ( "context" - "errors" - "fmt" "log/slog" - "net" - "strconv" - "sync/atomic" - "syscall" "time" - "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/auth" - "github.com/dicedb/dice/internal/clientio" "github.com/dicedb/dice/internal/clientio/iohandler" - "github.com/dicedb/dice/internal/clientio/requestparser" - "github.com/dicedb/dice/internal/cmd" - diceerrors "github.com/dicedb/dice/internal/errors" - "github.com/dicedb/dice/internal/ops" - "github.com/dicedb/dice/internal/querymanager" - "github.com/dicedb/dice/internal/shard" - "github.com/dicedb/dice/internal/wal" - "github.com/dicedb/dice/internal/watchmanager" - "github.com/google/uuid" ) const defaultRequestTimeout = 6 * time.Second -var requestCounter uint32 - // IOThread interface type IOThread interface { ID() string @@ -39,35 +20,21 @@ type IOThread interface { type BaseIOThread struct { IOThread - id string - ioHandler iohandler.IOHandler - parser requestparser.Parser - shardManager *shard.ShardManager - adhocReqChan chan *cmd.DiceDBCmd - Session *auth.Session - globalErrorChan chan error - responseChan chan *ops.StoreResponse - preprocessingChan chan *ops.StoreResponse - cmdWatchSubscriptionChan chan watchmanager.WatchSubscription - wl wal.AbstractWAL + id string + ioHandler iohandler.IOHandler + Session *auth.Session + ioThreadReadChan chan []byte // Channel to send data to the command handler + ioThreadWriteChan chan interface{} // Channel to receive data from the command handler } -func NewIOThread(wid string, responseChan, preprocessingChan chan *ops.StoreResponse, - cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, - ioHandler iohandler.IOHandler, parser requestparser.Parser, - shardManager *shard.ShardManager, gec chan error, wl wal.AbstractWAL) *BaseIOThread { +func NewIOThread(id string, ioHandler iohandler.IOHandler, + ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}) *BaseIOThread { return &BaseIOThread{ - id: wid, - ioHandler: ioHandler, - parser: parser, - shardManager: shardManager, - globalErrorChan: gec, - responseChan: responseChan, - preprocessingChan: preprocessingChan, - Session: auth.NewSession(), - adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize), - cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, - wl: wl, + id: id, + ioHandler: ioHandler, + Session: auth.NewSession(), + ioThreadReadChan: ioThreadReadChan, + ioThreadWriteChan: ioThreadWriteChan, } } @@ -76,9 +43,9 @@ func (t *BaseIOThread) ID() string { } func (t *BaseIOThread) Start(ctx context.Context) error { - errChan := make(chan error, 1) - incomingDataChan := make(chan []byte) - readErrChan := make(chan error) + // local channels to communicate between Start and startInputReader goroutine + incomingDataChan := make(chan []byte) // data channel + readErrChan := make(chan error) // error channel runCtx, runCancel := context.WithCancel(ctx) defer runCancel() @@ -94,17 +61,16 @@ func (t *BaseIOThread) Start(ctx context.Context) error { slog.Warn("Error stopping io-thread:", slog.String("id", t.id), slog.Any("error", err)) } return ctx.Err() - case err := <-errChan: - return t.handleError(err) - case cmdReq := <-t.adhocReqChan: - t.handleCmdRequestWithTimeout(ctx, errChan, []*cmd.DiceDBCmd{cmdReq}, true, defaultRequestTimeout) case data := <-incomingDataChan: - if err := t.processIncomingData(ctx, &data, errChan); err != nil { - return err - } + t.ioThreadReadChan <- data case err := <-readErrChan: slog.Debug("Read error in io-thread, connection closed possibly", slog.String("id", t.id), slog.Any("error", err)) return err + case resp := <-t.ioThreadWriteChan: + err := t.ioHandler.Write(ctx, resp) + if err != nil { + slog.Debug("Error sending response to client", slog.String("id", t.id), slog.Any("error", err)) + } } } } @@ -132,480 +98,8 @@ func (t *BaseIOThread) startInputReader(ctx context.Context, incomingDataChan ch } } -func (t *BaseIOThread) handleError(err error) error { - if err != nil { - if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { - slog.Debug("Connection closed for io-thread", slog.String("id", t.id), slog.Any("error", err)) - return err - } - } - return fmt.Errorf("error writing response: %v", err) -} - -func (t *BaseIOThread) processIncomingData(ctx context.Context, data *[]byte, errChan chan error) error { - commands, err := t.parser.Parse(*data) - - if err != nil { - err = t.ioHandler.Write(ctx, err) - if err != nil { - slog.Debug("Write error, connection closed possibly", slog.String("id", t.id), slog.Any("error", err)) - return err - } - return nil - } - - if len(commands) == 0 { - err = t.ioHandler.Write(ctx, fmt.Errorf("ERR: Invalid request")) - if err != nil { - slog.Debug("Write error, connection closed possibly", slog.String("id", t.id), slog.Any("error", err)) - return err - } - return nil - } - - // DiceDB supports clients to send only one request at a time - // We also need to ensure that the client is blocked until the response is received - if len(commands) > 1 { - err = t.ioHandler.Write(ctx, fmt.Errorf("ERR: Multiple commands not supported")) - if err != nil { - slog.Debug("Write error, connection closed possibly", slog.String("id", t.id), slog.Any("error", err)) - return err - } - } - - err = t.isAuthenticated(commands[0]) - if err != nil { - writeErr := t.ioHandler.Write(ctx, err) - if writeErr != nil { - slog.Debug("Write error, connection closed possibly", slog.Any("error", errors.Join(err, writeErr))) - return errors.Join(err, writeErr) - } - return nil - } - - t.handleCmdRequestWithTimeout(ctx, errChan, commands, false, defaultRequestTimeout) - return nil -} - -func (t *BaseIOThread) handleCmdRequestWithTimeout(ctx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) { - execCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - t.executeCommandHandler(execCtx, errChan, commands, isWatchNotification) -} - -func (t *BaseIOThread) executeCommandHandler(execCtx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) { - // Retrieve metadata for the command to determine if multisharding is supported. - meta, ok := CommandsMeta[commands[0].Cmd] - if ok && meta.preProcessing { - if err := meta.preProcessResponse(t, commands[0]); err != nil { - e := t.ioHandler.Write(execCtx, err) - if e != nil { - slog.Debug("Error executing for io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - } - } - - err := t.executeCommand(execCtx, commands[0], isWatchNotification) - if err != nil { - slog.Error("Error executing command", slog.String("id", t.id), slog.Any("error", err)) - if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { - slog.Debug("Connection closed for io-thread", slog.String("id", t.id), slog.Any("error", err)) - errChan <- err - } - } -} - -func (t *BaseIOThread) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) error { - // Break down the single command into multiple commands if multisharding is supported. - // The length of cmdList helps determine how many shards to wait for responses. - cmdList := make([]*cmd.DiceDBCmd, 0) - var watchLabel string - - // Retrieve metadata for the command to determine if multisharding is supported. - meta, ok := CommandsMeta[diceDBCmd.Cmd] - if !ok { - // If no metadata exists, treat it as a single command and not migrated - cmdList = append(cmdList, diceDBCmd) - } else { - // Depending on the command type, decide how to handle it. - switch meta.CmdType { - case Global: - // If it's a global command, process it immediately without involving any shards. - err := t.ioHandler.Write(ctx, meta.IOThreadHandler(diceDBCmd.Args)) - slog.Debug("Error executing command on io-thread", slog.String("id", t.id), slog.Any("error", err)) - return err - - case SingleShard: - // For single-shard or custom commands, process them without breaking up. - cmdList = append(cmdList, diceDBCmd) - - case MultiShard, AllShard: - var err error - // If the command supports multisharding, break it down into multiple commands. - cmdList, err = meta.decomposeCommand(ctx, t, diceDBCmd) - if err != nil { - var ioErr error - // Check if it's a CustomError - var customErr *diceerrors.PreProcessError - if errors.As(err, &customErr) { - ioErr = t.ioHandler.Write(ctx, customErr.Result) - } else { - ioErr = t.ioHandler.Write(ctx, err) - } - if ioErr != nil { - slog.Debug("Error executing for io-thread", slog.String("id", t.id), slog.Any("error", ioErr)) - } - return ioErr - } - - case Custom: - return t.handleCustomCommands(ctx, diceDBCmd) - - case Watch: - // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass - // it along as is. - // Modify the command name to remove the .WATCH suffix, this will allow us to generate a consistent - // fingerprint (which uses the command name without the suffix) - diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-6] - - // check if the last argument is a watch label - label := diceDBCmd.Args[len(diceDBCmd.Args)-1] - if _, err := uuid.Parse(label); err == nil { - watchLabel = label - - // remove the watch label from the args - diceDBCmd.Args = diceDBCmd.Args[:len(diceDBCmd.Args)-1] - } - - watchCmd := &cmd.DiceDBCmd{ - Cmd: diceDBCmd.Cmd, - Args: diceDBCmd.Args, - } - cmdList = append(cmdList, watchCmd) - isWatchNotification = true - - case Unwatch: - // Generate the Cmd being unwatched. All we need to do is remove the .UNWATCH suffix from the command and pass - // it along as is. - // Modify the command name to remove the .UNWATCH suffix, this will allow us to generate a consistent - // fingerprint (which uses the command name without the suffix) - diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-8] - watchCmd := &cmd.DiceDBCmd{ - Cmd: diceDBCmd.Cmd, - Args: diceDBCmd.Args, - } - cmdList = append(cmdList, watchCmd) - isWatchNotification = false - } - } - - // Unsubscribe Unwatch command type - if meta.CmdType == Unwatch { - return t.handleCommandUnwatch(ctx, cmdList) - } - - // Scatter the broken-down commands to the appropriate shards. - if err := t.scatter(ctx, cmdList, meta.CmdType); err != nil { - return err - } - - // Gather the responses from the shards and write them to the buffer. - if err := t.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel); err != nil { - return err - } - - if meta.CmdType == Watch { - // Proceed to subscribe after successful execution - t.handleCommandWatch(cmdList) - } - - return nil -} - -func (t *BaseIOThread) handleCustomCommands(ctx context.Context, diceDBCmd *cmd.DiceDBCmd) error { - // if command is of type Custom, write a custom logic around it - switch diceDBCmd.Cmd { - case CmdAuth: - err := t.ioHandler.Write(ctx, t.RespAuth(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending auth response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - case CmdEcho: - err := t.ioHandler.Write(ctx, RespEcho(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending echo response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - case CmdAbort: - err := t.ioHandler.Write(ctx, clientio.OK) - if err != nil { - slog.Error("Error sending abort response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - slog.Info("Received ABORT command, initiating server shutdown", slog.String("id", t.id)) - t.globalErrorChan <- diceerrors.ErrAborted - return err - case CmdPing: - err := t.ioHandler.Write(ctx, RespPING(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending ping response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - case CmdHello: - err := t.ioHandler.Write(ctx, RespHello(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending ping response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - case CmdSleep: - err := t.ioHandler.Write(ctx, RespSleep(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending ping response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - default: - return diceerrors.ErrUnknownCmd(diceDBCmd.Cmd) - } -} - -// handleCommandWatch sends a watch subscription request to the watch manager. -func (t *BaseIOThread) handleCommandWatch(cmdList []*cmd.DiceDBCmd) { - t.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ - Subscribe: true, - WatchCmd: cmdList[len(cmdList)-1], - AdhocReqChan: t.adhocReqChan, - } -} - -// handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client. -// The response is sent before the unwatch request is processed by the watch manager. -func (t *BaseIOThread) handleCommandUnwatch(ctx context.Context, cmdList []*cmd.DiceDBCmd) error { - // extract the fingerprint - command := cmdList[len(cmdList)-1] - fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32) - if parseErr != nil { - err := t.ioHandler.Write(ctx, diceerrors.ErrInvalidFingerprint) - if err != nil { - return fmt.Errorf("error sending push response to client: %v", err) - } - return parseErr - } - - // send the unsubscribe request - t.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ - Subscribe: false, - AdhocReqChan: t.adhocReqChan, - Fingerprint: uint32(fp), - } - - err := t.ioHandler.Write(ctx, clientio.RespOK) - if err != nil { - return fmt.Errorf("error sending push response to client: %v", err) - } - return nil -} - -// scatter distributes the DiceDB commands to the respective shards based on the key. -// For each command, it calculates the shard ID and sends the command to the shard's request channel for processing. -func (t *BaseIOThread) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd, cmdType CmdType) error { - // Otherwise check for the shard based on the key using hash - // and send it to the particular shard - // Check if the context has been canceled or expired. - select { - case <-ctx.Done(): - // If the context is canceled, return the error associated with it. - return ctx.Err() - default: - // Proceed with the default case when the context is not canceled. - - if cmdType == AllShard { - // If the command type is for all shards, iterate over all available shards. - for i := uint8(0); i < uint8(t.shardManager.GetShardCount()); i++ { - // Get the shard ID (i) and its associated request channel. - shardID, responseChan := i, t.shardManager.GetShard(i).ReqChan - - // Send a StoreOp operation to the shard's request channel. - responseChan <- &ops.StoreOp{ - SeqID: i, // Sequence ID for this operation. - RequestID: GenerateUniqueRequestID(), // Unique identifier for the request. - Cmd: cmds[0], // Command to be executed, using the first command in cmds. - IOThreadID: t.id, // ID of the current io-thread. - ShardID: shardID, // ID of the shard handling this operation. - Client: nil, // Client information (if applicable). - } - } - } else { - // If the command type is specific to certain commands, process them individually. - for i := uint8(0); i < uint8(len(cmds)); i++ { - // Determine the appropriate shard for the current command using a routing key. - shardID, responseChan := t.shardManager.GetShardInfo(getRoutingKeyFromCommand(cmds[i])) - - // Send a StoreOp operation to the shard's request channel. - responseChan <- &ops.StoreOp{ - SeqID: i, // Sequence ID for this operation. - RequestID: GenerateUniqueRequestID(), // Unique identifier for the request. - Cmd: cmds[i], // Command to be executed, using the current command in cmds. - IOThreadID: t.id, // ID of the current io-thread. - ShardID: shardID, // ID of the shard handling this operation. - Client: nil, // Client information (if applicable). - } - } - } - } - - return nil -} - -// getRoutingKeyFromCommand determines the key used for shard routing -func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string { - if len(diceDBCmd.Args) > 0 { - return diceDBCmd.Args[0] - } - return diceDBCmd.Cmd -} - -// gather collects the responses from multiple shards and writes the results into the provided buffer. -// It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). -func (t *BaseIOThread) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) error { - // Collect responses from all shards - storeOp, err := t.gatherResponses(ctx, numCmds) - if err != nil { - return err - } - - if len(storeOp) == 0 { - slog.Error("No response from shards", - slog.String("id", t.id), - slog.String("command", diceDBCmd.Cmd)) - return fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd) - } - - if isWatchNotification { - return t.handleWatchNotification(ctx, diceDBCmd, storeOp[0], watchLabel) - } - - // Process command based on its type - cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd] - if !ok { - return t.handleUnsupportedCommand(ctx, storeOp[0]) - } - - return t.handleCommand(ctx, cmdMeta, diceDBCmd, storeOp) -} - -// gatherResponses collects responses from all shards -func (t *BaseIOThread) gatherResponses(ctx context.Context, numCmds int) ([]ops.StoreResponse, error) { - storeOp := make([]ops.StoreResponse, 0, numCmds) - - for numCmds > 0 { - select { - case <-ctx.Done(): - slog.Error("Timed out waiting for response from shards", - slog.String("id", t.id), - slog.Any("error", ctx.Err())) - return nil, ctx.Err() - - case resp, ok := <-t.responseChan: - if ok { - storeOp = append(storeOp, *resp) - } - numCmds-- - - case sError, ok := <-t.shardManager.ShardErrorChan: - if ok { - slog.Error("Error from shard", - slog.String("id", t.id), - slog.Any("error", sError)) - return nil, sError.Error - } - } - } - - return storeOp, nil -} - -// handleWatchNotification processes watch notification responses -func (t *BaseIOThread) handleWatchNotification(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) error { - fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint()) - - // if watch label is not empty, then this is the first response for the watch command - // hence, we will send the watch label as part of the response - firstRespElem := diceDBCmd.Cmd - if watchLabel != "" { - firstRespElem = watchLabel - } - - if resp.EvalResponse.Error != nil { - return t.writeResponse(ctx, querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error)) - } - - return t.writeResponse(ctx, querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result)) -} - -// handleUnsupportedCommand processes commands not in CommandsMeta -func (t *BaseIOThread) handleUnsupportedCommand(ctx context.Context, resp ops.StoreResponse) error { - if resp.EvalResponse.Error != nil { - return t.writeResponse(ctx, resp.EvalResponse.Error) - } - return t.writeResponse(ctx, resp.EvalResponse.Result) -} - -// handleCommand processes commands based on their type -func (t *BaseIOThread) handleCommand(ctx context.Context, cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) error { - var err error - - switch cmdMeta.CmdType { - case SingleShard, Custom: - if storeOp[0].EvalResponse.Error != nil { - err = t.writeResponse(ctx, storeOp[0].EvalResponse.Error) - } else { - err = t.writeResponse(ctx, storeOp[0].EvalResponse.Result) - } - - if err == nil && t.wl != nil { - t.wl.LogCommand(diceDBCmd) - } - case MultiShard, AllShard: - err = t.writeResponse(ctx, cmdMeta.composeResponse(storeOp...)) - - if err == nil && t.wl != nil { - t.wl.LogCommand(diceDBCmd) - } - default: - slog.Error("Unknown command type", - slog.String("id", t.id), - slog.String("command", diceDBCmd.Cmd), - slog.Any("evalResp", storeOp)) - err = t.writeResponse(ctx, diceerrors.ErrInternalServer) - } - return err -} - -// writeResponse handles writing responses and logging errors -func (t *BaseIOThread) writeResponse(ctx context.Context, response interface{}) error { - err := t.ioHandler.Write(ctx, response) - if err != nil { - slog.Debug("Error sending response to client", - slog.String("id", t.id), - slog.Any("error", err)) - } - return err -} - -func (t *BaseIOThread) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error { - if diceDBCmd.Cmd != auth.Cmd && !t.Session.IsActive() { - return errors.New("NOAUTH Authentication required") - } - - return nil -} - func (t *BaseIOThread) Stop() error { slog.Info("Stopping io-thread", slog.String("id", t.id)) t.Session.Expire() return nil } - -func GenerateUniqueRequestID() uint32 { - return atomic.AddUint32(&requestCounter, 1) -} diff --git a/internal/iothread/manager.go b/internal/iothread/manager.go index 573117929..465996837 100644 --- a/internal/iothread/manager.go +++ b/internal/iothread/manager.go @@ -4,15 +4,12 @@ import ( "errors" "sync" "sync/atomic" - - "github.com/dicedb/dice/internal/shard" ) type Manager struct { connectedClients sync.Map numIOThreads atomic.Int32 maxClients int32 - shardManager *shard.ShardManager mu sync.Mutex } @@ -21,10 +18,9 @@ var ( ErrIOThreadNotFound = errors.New("io-thread not found") ) -func NewManager(maxClients int32, sm *shard.ShardManager) *Manager { +func NewManager(maxClients int32) *Manager { return &Manager{ - maxClients: maxClients, - shardManager: sm, + maxClients: maxClients, } } @@ -37,14 +33,6 @@ func (m *Manager) RegisterIOThread(ioThread IOThread) error { } m.connectedClients.Store(ioThread.ID(), ioThread) - responseChan := ioThread.(*BaseIOThread).responseChan - preprocessingChan := ioThread.(*BaseIOThread).preprocessingChan - - if responseChan != nil && preprocessingChan != nil { - m.shardManager.RegisterIOThread(ioThread.ID(), responseChan, preprocessingChan) // TODO: Change responseChan type to ShardResponse - } else if responseChan != nil && preprocessingChan == nil { - m.shardManager.RegisterIOThread(ioThread.ID(), responseChan, nil) - } m.numIOThreads.Add(1) return nil @@ -72,8 +60,6 @@ func (m *Manager) UnregisterIOThread(id string) error { return ErrIOThreadNotFound } - m.shardManager.UnregisterIOThread(id) m.numIOThreads.Add(-1) - return nil } diff --git a/internal/ops/store_op.go b/internal/ops/store_op.go index b151066e9..0ca912137 100644 --- a/internal/ops/store_op.go +++ b/internal/ops/store_op.go @@ -11,8 +11,8 @@ type StoreOp struct { RequestID uint32 // RequestID identifies the request that this StoreOp belongs to Cmd *cmd.DiceDBCmd // Cmd is the atomic Store command (e.g., GET, SET) ShardID uint8 // ShardID of the shard on which the Store command will be executed - IOThreadID string // IOThreadID is the ID of the io-thread that sent this Store operation - Client *comm.Client // Client that sent this Store operation. TODO: This can potentially replace the IOThreadID in the future + CmdHandlerID string // CmdHandlerID is the ID of the command handler that sent this Store operation + Client *comm.Client // Client that sent this Store operation. TODO: This can potentially replace the CmdHandlerID in the future HTTPOp bool // HTTPOp is true if this Store operation is an HTTP operation WebsocketOp bool // WebsocketOp is true if this Store operation is a Websocket operation PreProcessing bool // PreProcessing indicates whether a comamnd operation requires preprocessing before execution. This is mainly used is multi-step-multi-shard commands diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index d472dbc62..f02ac8655 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -27,8 +27,9 @@ import ( ) const ( - Abort = "ABORT" - stringNil = "(nil)" + Abort = "ABORT" + stringNil = "(nil)" + httpCmdHandlerID = "httpServer" ) var unimplementedCommands = map[string]bool{ @@ -96,7 +97,7 @@ func (s *HTTPServer) Run(ctx context.Context) error { httpCtx, cancelHTTP := context.WithCancel(ctx) defer cancelHTTP() - s.shardManager.RegisterIOThread("httpServer", s.ioChan, nil) + s.shardManager.RegisterCommandHandler(httpCmdHandlerID, s.ioChan, nil) wg.Add(1) go func() { @@ -167,10 +168,10 @@ func (s *HTTPServer) DiceHTTPHandler(writer http.ResponseWriter, request *http.R // send request to Shard Manager s.shardManager.GetShard(0).ReqChan <- &ops.StoreOp{ - Cmd: diceDBCmd, - IOThreadID: "httpServer", - ShardID: 0, - HTTPOp: true, + Cmd: diceDBCmd, + CmdHandlerID: httpCmdHandlerID, + ShardID: 0, + HTTPOp: true, } // Wait for response @@ -218,11 +219,11 @@ func (s *HTTPServer) DiceHTTPQwatchHandler(writer http.ResponseWriter, request * qwatchClient := comm.NewHTTPQwatchClient(s.qwatchResponseChan, clientIdentifierID) // Prepare the store operation storeOp := &ops.StoreOp{ - Cmd: diceDBCmd, - IOThreadID: "httpServer", - ShardID: 0, - Client: qwatchClient, - HTTPOp: true, + Cmd: diceDBCmd, + CmdHandlerID: httpCmdHandlerID, + ShardID: 0, + Client: qwatchClient, + HTTPOp: true, } slog.Info("Registered client for watching query", slog.Any("clientID", clientIdentifierID), diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go index 51e205979..949e578ef 100644 --- a/internal/server/resp/server.go +++ b/internal/server/resp/server.go @@ -11,6 +11,8 @@ import ( "syscall" "time" + "github.com/dicedb/dice/internal/commandhandler" + "github.com/dicedb/dice/internal/ops" "github.com/dicedb/dice/internal/server/abstractserver" "github.com/dicedb/dice/internal/wal" @@ -21,13 +23,13 @@ import ( "github.com/dicedb/dice/internal/clientio/iohandler/netconn" respparser "github.com/dicedb/dice/internal/clientio/requestparser/resp" "github.com/dicedb/dice/internal/iothread" - "github.com/dicedb/dice/internal/ops" "github.com/dicedb/dice/internal/shard" ) var ( - ioThreadCounter uint64 - startTime = time.Now().UnixNano() / int64(time.Millisecond) + ioThreadCounter uint64 + cmdHandlerCounter uint64 + startTime = time.Now().UnixNano() / int64(time.Millisecond) ) var ( @@ -45,6 +47,7 @@ type Server struct { serverFD int connBacklogSize int ioThreadManager *iothread.Manager + cmdHandlerManager *commandhandler.Manager shardManager *shard.ShardManager watchManager *watchmanager.Manager cmdWatchSubscriptionChan chan watchmanager.WatchSubscription @@ -52,13 +55,15 @@ type Server struct { wl wal.AbstractWAL } -func NewServer(shardManager *shard.ShardManager, ioThreadManager *iothread.Manager, - cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent, globalErrChan chan error, wl wal.AbstractWAL) *Server { +func NewServer(shardManager *shard.ShardManager, ioThreadManager *iothread.Manager, cmdHandlerManager *commandhandler.Manager, + cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent, + globalErrChan chan error, wl wal.AbstractWAL) *Server { return &Server{ Host: config.DiceConfig.RespServer.Addr, Port: config.DiceConfig.RespServer.Port, connBacklogSize: DefaultConnBacklogSize, ioThreadManager: ioThreadManager, + cmdHandlerManager: cmdHandlerManager, shardManager: shardManager, watchManager: watchmanager.NewManager(cmdWatchSubscriptionChan, cmdWatchChan), cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, @@ -191,13 +196,21 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou return err } - parser := respparser.NewParser() + // create a new io-thread + ioThreadID := GenerateUniqueIOThreadID() + ioThreadReadChan := make(chan []byte) // for sending data to the command handler from the io-thread + ioThreadWriteChan := make(chan interface{}) // for sending data to the io-thread from the command handler + thread := iothread.NewIOThread(ioThreadID, ioHandler, ioThreadReadChan, ioThreadWriteChan) + // For each io-thread, we create a dedicated command handler - 1:1 mapping + cmdHandlerID := GenerateUniqueCommandHandlerID() + parser := respparser.NewParser() responseChan := make(chan *ops.StoreResponse) // responseChan is used for handling common responses from shards preprocessingChan := make(chan *ops.StoreResponse) // preprocessingChan is specifically for handling responses from shards for commands that require preprocessing - ioThreadID := GenerateUniqueIOThreadID() - thread := iothread.NewIOThread(ioThreadID, responseChan, preprocessingChan, s.cmdWatchSubscriptionChan, ioHandler, parser, s.shardManager, s.globalErrorChan, s.wl) + handler := commandhandler.NewCommandHandler(cmdHandlerID, responseChan, preprocessingChan, + s.cmdWatchSubscriptionChan, parser, s.shardManager, + s.globalErrorChan, ioThreadReadChan, ioThreadWriteChan, s.wl) // Register the io-thread with the manager err = s.ioThreadManager.RegisterIOThread(thread) @@ -207,6 +220,15 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou wg.Add(1) go s.startIOThread(ctx, wg, thread) + + // Register the command handler with the manager + err = s.cmdHandlerManager.RegisterCommandHandler(handler) + if err != nil { + return err + } + + wg.Add(1) + go s.startCommandHandler(ctx, wg, handler) } } } @@ -227,12 +249,34 @@ func (s *Server) startIOThread(ctx context.Context, wg *sync.WaitGroup, thread * } } +func (s *Server) startCommandHandler(ctx context.Context, wg *sync.WaitGroup, cmdHandler *commandhandler.BaseCommandHandler) { + wg.Done() + defer func(wm *commandhandler.Manager, id string) { + err := wm.UnregisterCommandHandler(id) + if err != nil { + slog.Warn("Failed to unregister command handler", slog.String("id", id), slog.Any("error", err)) + } + }(s.cmdHandlerManager, cmdHandler.ID()) + ctx2, cancel := context.WithCancel(ctx) + defer cancel() + err := cmdHandler.Start(ctx2) + if err != nil { + slog.Debug("CommandHandler stopped", slog.String("id", cmdHandler.ID()), slog.Any("error", err)) + } +} + func GenerateUniqueIOThreadID() string { count := atomic.AddUint64(&ioThreadCounter, 1) timestamp := time.Now().UnixNano()/int64(time.Millisecond) - startTime return fmt.Sprintf("W-%d-%d", timestamp, count) } +func GenerateUniqueCommandHandlerID() string { + count := atomic.AddUint64(&cmdHandlerCounter, 1) + timestamp := time.Now().UnixNano()/int64(time.Millisecond) - startTime + return fmt.Sprintf("W-%d-%d", timestamp, count) +} + func (s *Server) Shutdown() { // Not implemented } diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index 93148614e..391358f3f 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -29,9 +29,12 @@ import ( "golang.org/x/exp/rand" ) -const Qwatch = "Q.WATCH" -const Qunwatch = "Q.UNWATCH" -const Subscribe = "SUBSCRIBE" +const ( + Qwatch = "Q.WATCH" + Qunwatch = "Q.UNWATCH" + Subscribe = "SUBSCRIBE" + wsCmdHandlerID = "wsServer" +) var unimplementedCommandsWebsocket = map[string]bool{ Qunwatch: true, @@ -79,7 +82,7 @@ func (s *WebsocketServer) Run(ctx context.Context) error { websocketCtx, cancelWebsocket := context.WithCancel(ctx) defer cancelWebsocket() - s.shardManager.RegisterIOThread("wsServer", s.ioChan, nil) + s.shardManager.RegisterCommandHandler(wsCmdHandlerID, s.ioChan, nil) wg.Add(1) go func() { @@ -168,10 +171,10 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques // create request sp := &ops.StoreOp{ - Cmd: diceDBCmd, - IOThreadID: "wsServer", - ShardID: 0, - WebsocketOp: true, + Cmd: diceDBCmd, + CmdHandlerID: wsCmdHandlerID, + ShardID: 0, + WebsocketOp: true, } // handle q.watch commands diff --git a/internal/shard/shard_manager.go b/internal/shard/shard_manager.go index a3835b416..af1a90bab 100644 --- a/internal/shard/shard_manager.go +++ b/internal/shard/shard_manager.go @@ -104,15 +104,16 @@ func (manager *ShardManager) GetShard(id ShardID) *ShardThread { return nil } -// RegisterIOThread registers a io-thread with all Shards present in the ShardManager. -func (manager *ShardManager) RegisterIOThread(id string, request, processing chan *ops.StoreResponse) { +// RegisterCommandHandler registers a command handler with all Shards present in the ShardManager. +func (manager *ShardManager) RegisterCommandHandler(id string, request, processing chan *ops.StoreResponse) { for _, shard := range manager.shards { - shard.registerIOThread(id, request, processing) + shard.registerCommandHandler(id, request, processing) } } -func (manager *ShardManager) UnregisterIOThread(id string) { +// UnregisterCommandHandler unregisters a command handler from all Shards present in the ShardManager. +func (manager *ShardManager) UnregisterCommandHandler(id string) { for _, shard := range manager.shards { - shard.unregisterIOThread(id) + shard.unregisterCommandHandler(id) } } diff --git a/internal/shard/shard_thread.go b/internal/shard/shard_thread.go index 313e3e4d5..4fc40bf5f 100644 --- a/internal/shard/shard_thread.go +++ b/internal/shard/shard_thread.go @@ -21,23 +21,23 @@ type ShardError struct { Error error // Error is the error that occurred } -// IOChannels holds the communication channels for an io-thread. -// It contains both the common response channel and the preprocessing response channel. -type IOChannels struct { - CommonResponseChan chan *ops.StoreResponse // CommonResponseChan is used to send standard responses for io-thread operations. +// CmdHandlerChannels holds the communication channels for a Command Handler. +// It contains both the response channel and the preprocessing response channel. +type CmdHandlerChannels struct { + ResponseChan chan *ops.StoreResponse // ResponseChan is used to send standard responses for Command Handler operations. PreProcessingResponseChan chan *ops.StoreResponse // PreProcessingResponseChan is used to send responses related to preprocessing operations. } type ShardThread struct { - id ShardID // id is the unique identifier for the shard. - store *dstore.Store // store that the shard is responsible for. - ReqChan chan *ops.StoreOp // ReqChan is this shard's channel for receiving requests. - ioThreadMap map[string]IOChannels // ioThreadMap maps each io-thread id to its corresponding IOChannels, containing both the common and preprocessing response channels. - mu sync.RWMutex // mu is the ioThreadMap's mutex for thread safety. - globalErrorChan chan error // globalErrorChan is the channel for sending system-level errors. - shardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors. - lastCronExecTime time.Time // lastCronExecTime is the last time the shard executed cron tasks. - cronFrequency time.Duration // cronFrequency is the frequency at which the shard executes cron tasks. + id ShardID // id is the unique identifier for the shard. + store *dstore.Store // store that the shard is responsible for. + ReqChan chan *ops.StoreOp // ReqChan is this shard's channel for receiving requests. + cmdHandlerMap map[string]CmdHandlerChannels // cmdHandlerMap maps each command handler id to its corresponding CommandHandlerChannels, containing both the common and preprocessing response channels. + mu sync.RWMutex // mu is the cmdHandlerMap's mutex for thread safety. + globalErrorChan chan error // globalErrorChan is the channel for sending system-level errors. + shardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors. + lastCronExecTime time.Time // lastCronExecTime is the last time the shard executed cron tasks. + cronFrequency time.Duration // cronFrequency is the frequency at which the shard executes cron tasks. } // NewShardThread creates a new ShardThread instance with the given shard id and error channel. @@ -47,7 +47,7 @@ func NewShardThread(id ShardID, gec chan error, sec chan *ShardError, queryWatch id: id, store: dstore.NewStore(queryWatchChan, cmdWatchChan, evictionStrategy), ReqChan: make(chan *ops.StoreOp, 1000), - ioThreadMap: make(map[string]IOChannels), + cmdHandlerMap: make(map[string]CmdHandlerChannels), globalErrorChan: gec, shardErrorChan: sec, lastCronExecTime: utils.GetCurrentTime(), @@ -79,30 +79,30 @@ func (shard *ShardThread) runCronTasks() { shard.lastCronExecTime = utils.GetCurrentTime() } -func (shard *ShardThread) registerIOThread(id string, responseChan, preprocessingChan chan *ops.StoreResponse) { +func (shard *ShardThread) registerCommandHandler(id string, responseChan, preprocessingChan chan *ops.StoreResponse) { shard.mu.Lock() - shard.ioThreadMap[id] = IOChannels{ - CommonResponseChan: responseChan, + shard.cmdHandlerMap[id] = CmdHandlerChannels{ + ResponseChan: responseChan, PreProcessingResponseChan: preprocessingChan, } shard.mu.Unlock() } -func (shard *ShardThread) unregisterIOThread(id string) { +func (shard *ShardThread) unregisterCommandHandler(id string) { shard.mu.Lock() - delete(shard.ioThreadMap, id) + delete(shard.cmdHandlerMap, id) shard.mu.Unlock() } // processRequest processes a Store operation for the shard. func (shard *ShardThread) processRequest(op *ops.StoreOp) { shard.mu.RLock() - ioChannels, ok := shard.ioThreadMap[op.IOThreadID] + channels, ok := shard.cmdHandlerMap[op.CmdHandlerID] shard.mu.RUnlock() - ioThreadChan := ioChannels.CommonResponseChan - preProcessChan := ioChannels.PreProcessingResponseChan + cmdHandlerChan := channels.ResponseChan + preProcessChan := channels.PreProcessingResponseChan sp := &ops.StoreResponse{ RequestID: op.RequestID, @@ -124,11 +124,11 @@ func (shard *ShardThread) processRequest(op *ops.StoreOp) { } else { shard.shardErrorChan <- &ShardError{ ShardID: shard.id, - Error: fmt.Errorf(diceerrors.IOThreadNotFoundErr, op.IOThreadID), + Error: fmt.Errorf(diceerrors.CmdHandlerNotFoundErr, op.CmdHandlerID), } } - ioThreadChan <- sp + cmdHandlerChan <- sp } // cleanup handles cleanup logic when the shard stops. diff --git a/main.go b/main.go index 09744ae52..833030157 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "time" "github.com/dicedb/dice/internal/cli" + "github.com/dicedb/dice/internal/commandhandler" "github.com/dicedb/dice/internal/logger" "github.com/dicedb/dice/internal/server/abstractserver" "github.com/dicedb/dice/internal/wal" @@ -134,8 +135,10 @@ func main() { } defer stopProfiling() } - ioThreadManager := iothread.NewManager(config.DiceConfig.Performance.MaxClients, shardManager) - respServer := resp.NewServer(shardManager, ioThreadManager, cmdWatchSubscriptionChan, cmdWatchChan, serverErrCh, wl) + ioThreadManager := iothread.NewManager(config.DiceConfig.Performance.MaxClients) + cmdHandlerManager := commandhandler.NewManager(config.DiceConfig.Performance.MaxCmdHandlers, shardManager) + + respServer := resp.NewServer(shardManager, ioThreadManager, cmdHandlerManager, cmdWatchSubscriptionChan, cmdWatchChan, serverErrCh, wl) serverWg.Add(1) go runServer(ctx, &serverWg, respServer, serverErrCh) From 0ffb45d6850862ef161a326f12325966da91879e Mon Sep 17 00:00:00 2001 From: Prateek Singh Rathore Date: Wed, 4 Dec 2024 20:52:39 +0530 Subject: [PATCH 2/8] consolidated write to io channel in one place --- internal/commandhandler/commandhandler.go | 158 ++++++++++------------ 1 file changed, 70 insertions(+), 88 deletions(-) diff --git a/internal/commandhandler/commandhandler.go b/internal/commandhandler/commandhandler.go index 6313cac2e..623d9c848 100644 --- a/internal/commandhandler/commandhandler.go +++ b/internal/commandhandler/commandhandler.go @@ -88,85 +88,76 @@ func (h *BaseCommandHandler) Start(ctx context.Context) error { case err := <-errChan: return h.handleError(err) case data := <-h.ioThreadReadChan: - if err := h.processCommand(ctx, &data, h.globalErrorChan); err != nil { + resp, err := h.processCommand(ctx, &data, h.globalErrorChan) + if err != nil { + h.sendResponseToIOThread(err) return err } + h.sendResponseToIOThread(resp) } } } // processCommand processes commands recevied from io thread -func (h *BaseCommandHandler) processCommand(ctx context.Context, data *[]byte, errChan chan error) error { +func (h *BaseCommandHandler) processCommand(ctx context.Context, data *[]byte, gec chan error) (interface{}, error) { commands, err := h.parser.Parse(*data) if err != nil { slog.Debug("error parsing commands from io thread", slog.String("id", h.id), slog.Any("error", err)) - h.ioThreadWriteChan <- err - return nil + return nil, err } if len(commands) == 0 { slog.Debug("invalid request from io thread with zero length", slog.String("id", h.id)) - h.ioThreadWriteChan <- fmt.Errorf("ERR: Invalid request") - return nil + return nil, fmt.Errorf("ERR: Invalid request") } // DiceDB supports clients to send only one request at a time // We also need to ensure that the client is blocked until the response is received if len(commands) > 1 { - h.ioThreadWriteChan <- fmt.Errorf("ERR: Multiple commands not supported") - return nil + return nil, fmt.Errorf("ERR: Multiple commands not supported") } err = h.isAuthenticated(commands[0]) if err != nil { slog.Debug("command handler authentication failed", slog.String("id", h.id), slog.Any("error", err)) - h.ioThreadWriteChan <- err - return nil + return nil, err } - h.handleCmdRequestWithTimeout(ctx, errChan, commands, false, defaultRequestTimeout) - return nil + return h.handleCmdRequestWithTimeout(ctx, gec, commands, false, defaultRequestTimeout) } -func (h *BaseCommandHandler) handleCmdRequestWithTimeout(ctx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) { +func (h *BaseCommandHandler) handleCmdRequestWithTimeout(ctx context.Context, gec chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) (interface{}, error) { execCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - h.executeCommandHandler(execCtx, errChan, commands, isWatchNotification) -} - -func (h *BaseCommandHandler) handleError(err error) error { - if err != nil { - if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { - slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err)) - return err - } - } - return fmt.Errorf("error writing response: %v", err) + return h.executeCommandHandler(execCtx, gec, commands, isWatchNotification) } -func (h *BaseCommandHandler) executeCommandHandler(execCtx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) { +func (h *BaseCommandHandler) executeCommandHandler(execCtx context.Context, gec chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) (interface{}, error) { // Retrieve metadata for the command to determine if multisharding is supported. meta, ok := CommandsMeta[commands[0].Cmd] if ok && meta.preProcessing { if err := meta.preProcessResponse(h, commands[0]); err != nil { slog.Debug("error pre processing response", slog.String("id", h.id), slog.Any("error", err)) - h.ioThreadWriteChan <- err + return nil, err } } - err := h.executeCommand(execCtx, commands[0], isWatchNotification) + resp, err := h.executeCommand(execCtx, commands[0], isWatchNotification) + + // log error and send to global error channel if it's a connection error if err != nil { slog.Error("Error executing command", slog.String("id", h.id), slog.Any("error", err)) if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err)) - errChan <- err + gec <- err } - h.ioThreadWriteChan <- err } + + return resp, err } -func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) error { +func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) (interface{}, error) { // Break down the single command into multiple commands if multisharding is supported. // The length of cmdList helps determine how many shards to wait for responses. cmdList := make([]*cmd.DiceDBCmd, 0) @@ -181,9 +172,8 @@ func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd. // Depending on the command type, decide how to handle it. switch meta.CmdType { case Global: - // If it's a global command, process it immediately without involving any shards. - h.ioThreadWriteChan <- meta.CmdHandlerFunction(diceDBCmd.Args) - return nil + // process global command immediately without involving any shards. + return meta.CmdHandlerFunction(diceDBCmd.Args), nil case SingleShard: // For single-shard or custom commands, process them without breaking up. @@ -198,11 +188,10 @@ func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd. // Check if it's a CustomError var customErr *diceerrors.PreProcessError if errors.As(err, &customErr) { - h.ioThreadWriteChan <- customErr.Result + return nil, fmt.Errorf("%v", customErr.Result) } else { - h.ioThreadWriteChan <- err + return nil, err } - return nil } case Custom: @@ -253,47 +242,42 @@ func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd. // Scatter the broken-down commands to the appropriate shards. if err := h.scatter(ctx, cmdList, meta.CmdType); err != nil { - return err + return nil, err } // Gather the responses from the shards and write them to the buffer. - if err := h.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel); err != nil { - return err + resp, err := h.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel) + if err != nil { + return nil, err } + // Proceed to subscribe after successful execution if meta.CmdType == Watch { - // Proceed to subscribe after successful execution h.handleCommandWatch(cmdList) } - return nil + return resp, nil } -func (h *BaseCommandHandler) handleCustomCommands(diceDBCmd *cmd.DiceDBCmd) error { +func (h *BaseCommandHandler) handleCustomCommands(diceDBCmd *cmd.DiceDBCmd) (interface{}, error) { // if command is of type Custom, write a custom logic around it switch diceDBCmd.Cmd { case CmdAuth: - h.ioThreadWriteChan <- h.RespAuth(diceDBCmd.Args) - return nil + return h.RespAuth(diceDBCmd.Args), nil case CmdEcho: - h.ioThreadWriteChan <- RespEcho(diceDBCmd.Args) - return nil + return RespEcho(diceDBCmd.Args), nil case CmdAbort: - h.ioThreadWriteChan <- clientio.OK slog.Info("Received ABORT command, initiating server shutdown", slog.String("id", h.id)) h.globalErrorChan <- diceerrors.ErrAborted - return nil + return clientio.OK, nil case CmdPing: - h.ioThreadWriteChan <- RespPING(diceDBCmd.Args) - return nil + return RespPING(diceDBCmd.Args), nil case CmdHello: - h.ioThreadWriteChan <- RespHello(diceDBCmd.Args) - return nil + return RespHello(diceDBCmd.Args), nil case CmdSleep: - h.ioThreadWriteChan <- RespSleep(diceDBCmd.Args) - return nil + return RespSleep(diceDBCmd.Args), nil default: - return diceerrors.ErrUnknownCmd(diceDBCmd.Cmd) + return nil, diceerrors.ErrUnknownCmd(diceDBCmd.Cmd) } } @@ -308,13 +292,12 @@ func (h *BaseCommandHandler) handleCommandWatch(cmdList []*cmd.DiceDBCmd) { // handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client. // The response is sent before the unwatch request is processed by the watch manager. -func (h *BaseCommandHandler) handleCommandUnwatch(cmdList []*cmd.DiceDBCmd) error { +func (h *BaseCommandHandler) handleCommandUnwatch(cmdList []*cmd.DiceDBCmd) (interface{}, error) { // extract the fingerprint command := cmdList[len(cmdList)-1] fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32) if parseErr != nil { - h.ioThreadWriteChan <- diceerrors.ErrInvalidFingerprint - return nil + return nil, diceerrors.ErrInvalidFingerprint } // send the unsubscribe request @@ -324,8 +307,7 @@ func (h *BaseCommandHandler) handleCommandUnwatch(cmdList []*cmd.DiceDBCmd) erro Fingerprint: uint32(fp), } - h.ioThreadWriteChan <- clientio.OK - return nil + return clientio.OK, nil } // scatter distributes the DiceDB commands to the respective shards based on the key. @@ -389,28 +371,28 @@ func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string { // gather collects the responses from multiple shards and writes the results into the provided buffer. // It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). -func (h *BaseCommandHandler) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) error { +func (h *BaseCommandHandler) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) (interface{}, error) { // Collect responses from all shards storeOp, err := h.gatherResponses(ctx, numCmds) if err != nil { - return err + return nil, err } if len(storeOp) == 0 { slog.Error("No response from shards", slog.String("id", h.id), slog.String("command", diceDBCmd.Cmd)) - return fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd) + return nil, fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd) } if isWatchNotification { - return h.handleWatchNotification(ctx, diceDBCmd, storeOp[0], watchLabel) + return h.handleWatchNotification(diceDBCmd, storeOp[0], watchLabel) } // Process command based on its type cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd] if !ok { - return h.handleUnsupportedCommand(ctx, storeOp[0]) + return h.handleUnsupportedCommand(storeOp[0]) } return h.handleCommand(cmdMeta, diceDBCmd, storeOp) @@ -448,7 +430,7 @@ func (h *BaseCommandHandler) gatherResponses(ctx context.Context, numCmds int) ( } // handleWatchNotification processes watch notification responses -func (h *BaseCommandHandler) handleWatchNotification(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) error { +func (h *BaseCommandHandler) handleWatchNotification(diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) (interface{}, error) { fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint()) // if watch label is not empty, then this is the first response for the watch command @@ -459,55 +441,55 @@ func (h *BaseCommandHandler) handleWatchNotification(ctx context.Context, diceDB } if resp.EvalResponse.Error != nil { - return h.writeResponse(querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error)) + // This is a special case where error is returned as part of the watch response + return querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error), nil } - return h.writeResponse(querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result)) + return querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result), nil } // handleUnsupportedCommand processes commands not in CommandsMeta -func (h *BaseCommandHandler) handleUnsupportedCommand(ctx context.Context, resp ops.StoreResponse) error { +func (h *BaseCommandHandler) handleUnsupportedCommand(resp ops.StoreResponse) (interface{}, error) { if resp.EvalResponse.Error != nil { - return h.writeResponse(resp.EvalResponse.Error) + return nil, resp.EvalResponse.Error } - return h.writeResponse(resp.EvalResponse.Result) + return resp.EvalResponse.Result, nil } // handleCommand processes commands based on their type -func (h *BaseCommandHandler) handleCommand(cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) error { - var err error - +func (h *BaseCommandHandler) handleCommand(cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) (interface{}, error) { switch cmdMeta.CmdType { case SingleShard, Custom: if storeOp[0].EvalResponse.Error != nil { - err = h.writeResponse(storeOp[0].EvalResponse.Error) + return nil, storeOp[0].EvalResponse.Error } else { - err = h.writeResponse(storeOp[0].EvalResponse.Result) + return storeOp[0].EvalResponse.Result, nil } - if err == nil && h.wl != nil { - h.wl.LogCommand(diceDBCmd) - } case MultiShard, AllShard: - err = h.writeResponse(cmdMeta.composeResponse(storeOp...)) + return cmdMeta.composeResponse(storeOp...), nil - if err == nil && h.wl != nil { - h.wl.LogCommand(diceDBCmd) - } default: slog.Error("Unknown command type", slog.String("id", h.id), slog.String("command", diceDBCmd.Cmd), slog.Any("evalResp", storeOp)) - err = h.writeResponse(diceerrors.ErrInternalServer) + return nil, diceerrors.ErrInternalServer } - return err } -// writeResponse handles writing responses and logging errors -func (h *BaseCommandHandler) writeResponse(response interface{}) error { +func (h *BaseCommandHandler) handleError(err error) error { + if err != nil { + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { + slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err)) + return err + } + } + return fmt.Errorf("error writing response: %v", err) +} + +func (h *BaseCommandHandler) sendResponseToIOThread(response interface{}) { h.ioThreadWriteChan <- response - return nil } func (h *BaseCommandHandler) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error { From d54d4bb1ab3ef58c27fb72424339484721d3dd08 Mon Sep 17 00:00:00 2001 From: Prateek Singh Rathore Date: Wed, 4 Dec 2024 21:00:00 +0530 Subject: [PATCH 3/8] fixed linter warnings --- internal/commandhandler/commandhandler.go | 17 ++++++++++------- internal/iothread/iothread.go | 3 --- internal/shard/shard_thread.go | 19 ++++++++++--------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/internal/commandhandler/commandhandler.go b/internal/commandhandler/commandhandler.go index 623d9c848..71813565f 100644 --- a/internal/commandhandler/commandhandler.go +++ b/internal/commandhandler/commandhandler.go @@ -84,16 +84,16 @@ func (h *BaseCommandHandler) Start(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() case cmdReq := <-h.adhocReqChan: - h.handleCmdRequestWithTimeout(ctx, errChan, []*cmd.DiceDBCmd{cmdReq}, true, defaultRequestTimeout) + resp, err := h.handleCmdRequestWithTimeout(ctx, errChan, []*cmd.DiceDBCmd{cmdReq}, true, defaultRequestTimeout) + h.sendResponseToIOThread(resp, err) case err := <-errChan: return h.handleError(err) case data := <-h.ioThreadReadChan: resp, err := h.processCommand(ctx, &data, h.globalErrorChan) + h.sendResponseToIOThread(resp, err) if err != nil { - h.sendResponseToIOThread(err) return err } - h.sendResponseToIOThread(resp) } } } @@ -189,9 +189,8 @@ func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd. var customErr *diceerrors.PreProcessError if errors.As(err, &customErr) { return nil, fmt.Errorf("%v", customErr.Result) - } else { - return nil, err } + return nil, err } case Custom: @@ -488,8 +487,12 @@ func (h *BaseCommandHandler) handleError(err error) error { return fmt.Errorf("error writing response: %v", err) } -func (h *BaseCommandHandler) sendResponseToIOThread(response interface{}) { - h.ioThreadWriteChan <- response +func (h *BaseCommandHandler) sendResponseToIOThread(resp interface{}, err error) { + if err != nil { + h.ioThreadWriteChan <- err + return + } + h.ioThreadWriteChan <- resp } func (h *BaseCommandHandler) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error { diff --git a/internal/iothread/iothread.go b/internal/iothread/iothread.go index f95707d10..f01023523 100644 --- a/internal/iothread/iothread.go +++ b/internal/iothread/iothread.go @@ -3,14 +3,11 @@ package iothread import ( "context" "log/slog" - "time" "github.com/dicedb/dice/internal/auth" "github.com/dicedb/dice/internal/clientio/iohandler" ) -const defaultRequestTimeout = 6 * time.Second - // IOThread interface type IOThread interface { ID() string diff --git a/internal/shard/shard_thread.go b/internal/shard/shard_thread.go index 4fc40bf5f..587573615 100644 --- a/internal/shard/shard_thread.go +++ b/internal/shard/shard_thread.go @@ -29,15 +29,16 @@ type CmdHandlerChannels struct { } type ShardThread struct { - id ShardID // id is the unique identifier for the shard. - store *dstore.Store // store that the shard is responsible for. - ReqChan chan *ops.StoreOp // ReqChan is this shard's channel for receiving requests. - cmdHandlerMap map[string]CmdHandlerChannels // cmdHandlerMap maps each command handler id to its corresponding CommandHandlerChannels, containing both the common and preprocessing response channels. - mu sync.RWMutex // mu is the cmdHandlerMap's mutex for thread safety. - globalErrorChan chan error // globalErrorChan is the channel for sending system-level errors. - shardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors. - lastCronExecTime time.Time // lastCronExecTime is the last time the shard executed cron tasks. - cronFrequency time.Duration // cronFrequency is the frequency at which the shard executes cron tasks. + id ShardID // id is the unique identifier for the shard. + store *dstore.Store // store that the shard is responsible for. + ReqChan chan *ops.StoreOp // ReqChan is this shard's channel for receiving requests. + // cmdHandlerMap maps each command handler id to its corresponding CommandHandlerChannels, containing both the common and preprocessing response channels. + cmdHandlerMap map[string]CmdHandlerChannels + mu sync.RWMutex // mu is the cmdHandlerMap's mutex for thread safety. + globalErrorChan chan error // globalErrorChan is the channel for sending system-level errors. + shardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors. + lastCronExecTime time.Time // lastCronExecTime is the last time the shard executed cron tasks. + cronFrequency time.Duration // cronFrequency is the frequency at which the shard executes cron tasks. } // NewShardThread creates a new ShardThread instance with the given shard id and error channel. From 9b86bebcaebfbf696474b9a86e0f257ffe58df16 Mon Sep 17 00:00:00 2001 From: Prateek Singh Rathore Date: Thu, 5 Dec 2024 11:48:02 +0530 Subject: [PATCH 4/8] don't stop command handler at command level error --- internal/commandhandler/commandhandler.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/commandhandler/commandhandler.go b/internal/commandhandler/commandhandler.go index 71813565f..050f2059f 100644 --- a/internal/commandhandler/commandhandler.go +++ b/internal/commandhandler/commandhandler.go @@ -91,9 +91,6 @@ func (h *BaseCommandHandler) Start(ctx context.Context) error { case data := <-h.ioThreadReadChan: resp, err := h.processCommand(ctx, &data, h.globalErrorChan) h.sendResponseToIOThread(resp, err) - if err != nil { - return err - } } } } From 690430d2330633686e5a40ceb4d9b0728b2ac371 Mon Sep 17 00:00:00 2001 From: Prateek Singh Rathore Date: Thu, 5 Dec 2024 12:00:13 +0530 Subject: [PATCH 5/8] fixed type assertion for command handler --- internal/commandhandler/manager.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/commandhandler/manager.go b/internal/commandhandler/manager.go index a29d9c6a3..29b2e5f37 100644 --- a/internal/commandhandler/manager.go +++ b/internal/commandhandler/manager.go @@ -57,9 +57,9 @@ func (m *Manager) CommandHandlerCount() int32 { func (m *Manager) UnregisterCommandHandler(id string) error { m.ShardManager.UnregisterCommandHandler(id) - if client, loaded := m.activeCmdHandlers.LoadAndDelete(id); loaded { - w := client.(BaseCommandHandler) - if err := w.Stop(); err != nil { + if cmdHandler, loaded := m.activeCmdHandlers.LoadAndDelete(id); loaded { + ch := cmdHandler.(*BaseCommandHandler) + if err := ch.Stop(); err != nil { return err } } else { From 3e15eca3882fa0eae736e50418775f2a8be0ac11 Mon Sep 17 00:00:00 2001 From: Prateek Singh Rathore Date: Thu, 5 Dec 2024 13:04:51 +0530 Subject: [PATCH 6/8] added an error channel between io thread and command handler to signal exit --- internal/commandhandler/commandhandler.go | 6 +++++- internal/iothread/iothread.go | 4 +++- internal/server/resp/server.go | 7 ++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/internal/commandhandler/commandhandler.go b/internal/commandhandler/commandhandler.go index 050f2059f..e86eb0154 100644 --- a/internal/commandhandler/commandhandler.go +++ b/internal/commandhandler/commandhandler.go @@ -45,6 +45,7 @@ type BaseCommandHandler struct { globalErrorChan chan error ioThreadReadChan chan []byte // Channel to receive data from io-thread ioThreadWriteChan chan interface{} // Channel to send data to io-thread + ioThreadErrChan chan error // Channel to receive errors from io-thread responseChan chan *ops.StoreResponse // Channel to communicate with shard preprocessingChan chan *ops.StoreResponse // Channel to communicate with shard cmdWatchSubscriptionChan chan watchmanager.WatchSubscription @@ -54,7 +55,7 @@ type BaseCommandHandler struct { func NewCommandHandler(id string, responseChan, preprocessingChan chan *ops.StoreResponse, cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, parser requestparser.Parser, shardManager *shard.ShardManager, gec chan error, - ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}, + ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}, ioThreadErrChan chan error, wl wal.AbstractWAL) *BaseCommandHandler { return &BaseCommandHandler{ id: id, @@ -65,6 +66,7 @@ func NewCommandHandler(id string, responseChan, preprocessingChan chan *ops.Stor globalErrorChan: gec, ioThreadReadChan: ioThreadReadChan, ioThreadWriteChan: ioThreadWriteChan, + ioThreadErrChan: ioThreadErrChan, responseChan: responseChan, preprocessingChan: preprocessingChan, cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, @@ -83,6 +85,8 @@ func (h *BaseCommandHandler) Start(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() + case err := <-h.ioThreadErrChan: + return err case cmdReq := <-h.adhocReqChan: resp, err := h.handleCmdRequestWithTimeout(ctx, errChan, []*cmd.DiceDBCmd{cmdReq}, true, defaultRequestTimeout) h.sendResponseToIOThread(resp, err) diff --git a/internal/iothread/iothread.go b/internal/iothread/iothread.go index f01023523..4149ab1e0 100644 --- a/internal/iothread/iothread.go +++ b/internal/iothread/iothread.go @@ -22,16 +22,18 @@ type BaseIOThread struct { Session *auth.Session ioThreadReadChan chan []byte // Channel to send data to the command handler ioThreadWriteChan chan interface{} // Channel to receive data from the command handler + ioThreadErrChan chan error // Channel to receive errors from the ioHandler } func NewIOThread(id string, ioHandler iohandler.IOHandler, - ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}) *BaseIOThread { + ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}, ioThreadErrChan chan error) *BaseIOThread { return &BaseIOThread{ id: id, ioHandler: ioHandler, Session: auth.NewSession(), ioThreadReadChan: ioThreadReadChan, ioThreadWriteChan: ioThreadWriteChan, + ioThreadErrChan: ioThreadErrChan, } } diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go index 949e578ef..579d4bf62 100644 --- a/internal/server/resp/server.go +++ b/internal/server/resp/server.go @@ -200,7 +200,8 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou ioThreadID := GenerateUniqueIOThreadID() ioThreadReadChan := make(chan []byte) // for sending data to the command handler from the io-thread ioThreadWriteChan := make(chan interface{}) // for sending data to the io-thread from the command handler - thread := iothread.NewIOThread(ioThreadID, ioHandler, ioThreadReadChan, ioThreadWriteChan) + ioThreadErrChan := make(chan error, 1) // for receiving errors from the io-thread + thread := iothread.NewIOThread(ioThreadID, ioHandler, ioThreadReadChan, ioThreadWriteChan, ioThreadErrChan) // For each io-thread, we create a dedicated command handler - 1:1 mapping cmdHandlerID := GenerateUniqueCommandHandlerID() @@ -209,8 +210,8 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou preprocessingChan := make(chan *ops.StoreResponse) // preprocessingChan is specifically for handling responses from shards for commands that require preprocessing handler := commandhandler.NewCommandHandler(cmdHandlerID, responseChan, preprocessingChan, - s.cmdWatchSubscriptionChan, parser, s.shardManager, - s.globalErrorChan, ioThreadReadChan, ioThreadWriteChan, s.wl) + s.cmdWatchSubscriptionChan, parser, s.shardManager, s.globalErrorChan, + ioThreadReadChan, ioThreadWriteChan, ioThreadErrChan, s.wl) // Register the io-thread with the manager err = s.ioThreadManager.RegisterIOThread(thread) From 07ca2b767eb36e4204ac7177ec49c8dde4dd03b5 Mon Sep 17 00:00:00 2001 From: Prateek Singh Rathore Date: Thu, 5 Dec 2024 13:10:23 +0530 Subject: [PATCH 7/8] bug fix --- internal/iothread/iothread.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/iothread/iothread.go b/internal/iothread/iothread.go index 4149ab1e0..5309296cf 100644 --- a/internal/iothread/iothread.go +++ b/internal/iothread/iothread.go @@ -64,6 +64,7 @@ func (t *BaseIOThread) Start(ctx context.Context) error { t.ioThreadReadChan <- data case err := <-readErrChan: slog.Debug("Read error in io-thread, connection closed possibly", slog.String("id", t.id), slog.Any("error", err)) + t.ioThreadErrChan <- err return err case resp := <-t.ioThreadWriteChan: err := t.ioHandler.Write(ctx, resp) From 4aab95caf756853f64a3bae91539705a3a10b652 Mon Sep 17 00:00:00 2001 From: Prateek Singh Rathore Date: Mon, 16 Dec 2024 13:25:34 +0530 Subject: [PATCH 8/8] fix build fail --- internal/server/httpws/httpServer.go | 8 ++++---- internal/server/httpws/websocketServer.go | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/server/httpws/httpServer.go b/internal/server/httpws/httpServer.go index e67d0b95e..f4b82da6d 100644 --- a/internal/server/httpws/httpServer.go +++ b/internal/server/httpws/httpServer.go @@ -28,7 +28,7 @@ import ( "sync" "time" - "github.com/dicedb/dice/internal/iothread" + "github.com/dicedb/dice/internal/commandhandler" "github.com/dicedb/dice/internal/eval" "github.com/dicedb/dice/internal/server/abstractserver" @@ -158,7 +158,7 @@ func (s *HTTPServer) DiceHTTPHandler(writer http.ResponseWriter, request *http.R return } - if iothread.CommandsMeta[diceDBCmd.Cmd].CmdType == iothread.MultiShard { + if commandhandler.CommandsMeta[diceDBCmd.Cmd].CmdType == commandhandler.MultiShard { writeErrorResponse(writer, http.StatusBadRequest, "unsupported command", "Unsupported command received", slog.String("cmd", diceDBCmd.Cmd)) return @@ -355,9 +355,9 @@ func (s *HTTPServer) writeResponse(writer http.ResponseWriter, result *ops.Store // Check if the command is migrated, if it is we use EvalResponse values // else we use RESPParser to decode the response - _, ok := iothread.CommandsMeta[diceDBCmd.Cmd] + _, ok := commandhandler.CommandsMeta[diceDBCmd.Cmd] // TODO: Remove this conditional check and if (true) condition when all commands are migrated - if !ok || iothread.CommandsMeta[diceDBCmd.Cmd].CmdType == iothread.Custom { + if !ok || commandhandler.CommandsMeta[diceDBCmd.Cmd].CmdType == commandhandler.Custom { responseValue, err = DecodeEvalResponse(result.EvalResponse) if err != nil { slog.Error("Error decoding response", "error", err) diff --git a/internal/server/httpws/websocketServer.go b/internal/server/httpws/websocketServer.go index 66bbe79fb..f1fb8b142 100644 --- a/internal/server/httpws/websocketServer.go +++ b/internal/server/httpws/websocketServer.go @@ -30,7 +30,7 @@ import ( "syscall" "time" - "github.com/dicedb/dice/internal/iothread" + "github.com/dicedb/dice/internal/commandhandler" "github.com/dicedb/dice/internal/server/abstractserver" "github.com/dicedb/dice/internal/wal" @@ -173,7 +173,7 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques continue } - if iothread.CommandsMeta[diceDBCmd.Cmd].CmdType == iothread.MultiShard { + if commandhandler.CommandsMeta[diceDBCmd.Cmd].CmdType == commandhandler.MultiShard { if err := WriteResponseWithRetries(conn, []byte("error: unsupported command"), maxRetries); err != nil { slog.Debug(fmt.Sprintf("Error writing message: %v", err)) } @@ -298,7 +298,7 @@ func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.D var responseValue interface{} // Check if the command is migrated, if it is we use EvalResponse values // else we use RESPParser to decode the response - _, ok := iothread.CommandsMeta[diceDBCmd.Cmd] + _, ok := commandhandler.CommandsMeta[diceDBCmd.Cmd] // TODO: Remove this conditional check and if (true) condition when all commands are migrated if !ok { responseValue, err = DecodeEvalResponse(response.EvalResponse)