diff --git a/integration_tests/commands/resp/setup.go b/integration_tests/commands/resp/setup.go index c83aa0143..dffb6d216 100644 --- a/integration_tests/commands/resp/setup.go +++ b/integration_tests/commands/resp/setup.go @@ -29,6 +29,7 @@ import ( "testing" "time" + "github.com/dicedb/dice/internal/commandhandler" "github.com/dicedb/dice/internal/iothread" "github.com/dicedb/dice/internal/server/resp" "github.com/dicedb/dice/internal/wal" @@ -211,10 +212,12 @@ func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) { cmdWatchSubscriptionChan := make(chan watchmanager.WatchSubscription) gec := make(chan error) shardManager := shard.NewShardManager(1, cmdWatchChan, gec) - ioThreadManager := iothread.NewManager(20000, shardManager) + ioThreadManager := iothread.NewManager(20000) + cmdHandlerManager := commandhandler.NewRegistry(20000, shardManager) + // Initialize the RESP Server wl, _ := wal.NewNullWAL() - testServer := resp.NewServer(shardManager, ioThreadManager, cmdWatchSubscriptionChan, cmdWatchChan, gec, wl) + testServer := resp.NewServer(shardManager, ioThreadManager, cmdHandlerManager, cmdWatchSubscriptionChan, cmdWatchChan, gec, wl) ctx, cancel := context.WithCancel(context.Background()) fmt.Println("Starting the test server on port", config.DiceConfig.RespServer.Port) diff --git a/internal/iothread/cmd_compose.go b/internal/commandhandler/cmd_compose.go similarity index 98% rename from internal/iothread/cmd_compose.go rename to internal/commandhandler/cmd_compose.go index 081871f47..6ea026550 100644 --- a/internal/iothread/cmd_compose.go +++ b/internal/commandhandler/cmd_compose.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
-package iothread +package commandhandler import ( "math" @@ -24,7 +24,7 @@ import ( "github.com/dicedb/dice/internal/ops" ) -// This file contains functions used by the IOThread to handle and process responses +// This file contains functions used by the CommandHandler to handle and process responses // from multiple shards during distributed operations. For commands that are executed // across several shards, such as MultiShard commands, dedicated functions are responsible // for aggregating and managing the results. diff --git a/internal/iothread/cmd_custom.go b/internal/commandhandler/cmd_custom.go similarity index 95% rename from internal/iothread/cmd_custom.go rename to internal/commandhandler/cmd_custom.go index 5dd568a92..2fcc88d5a 100644 --- a/internal/iothread/cmd_custom.go +++ b/internal/commandhandler/cmd_custom.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package iothread +package commandhandler import ( "fmt" @@ -28,7 +28,7 @@ import ( // RespAuth returns with an encoded "OK" if the user is authenticated // If the user is not authenticated, it returns with an encoded error message -func (t *BaseIOThread) RespAuth(args []string) interface{} { +func (h *BaseCommandHandler) RespAuth(args []string) interface{} { // Check for incorrect number of arguments (arity error). 
if len(args) < 1 || len(args) > 2 { return diceerrors.ErrWrongArgumentCount("AUTH") @@ -47,7 +47,7 @@ func (t *BaseIOThread) RespAuth(args []string) interface{} { username, password = args[0], args[1] } - if err := t.Session.Validate(username, password); err != nil { + if err := h.Session.Validate(username, password); err != nil { return err } diff --git a/internal/iothread/cmd_decompose.go b/internal/commandhandler/cmd_decompose.go similarity index 80% rename from internal/iothread/cmd_decompose.go rename to internal/commandhandler/cmd_decompose.go index 0ba439aa3..b9e881153 100644 --- a/internal/iothread/cmd_decompose.go +++ b/internal/commandhandler/cmd_decompose.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package iothread +package commandhandler import ( "context" @@ -28,7 +28,7 @@ import ( "github.com/dicedb/dice/internal/store" ) -// This file is utilized by the IOThread to decompose commands that need to be executed +// This file is utilized by the CommandHandler to decompose commands that need to be executed // across multiple shards. For commands that operate on multiple keys or necessitate // distribution across shards (e.g., MultiShard commands), a Breakup function is invoked // to transform the original command into multiple smaller commands, each directed at @@ -41,13 +41,13 @@ import ( // decomposeRename breaks down the RENAME command into separate DELETE and SET commands. // It first waits for the result of a GET command from shards. If successful, it removes // the old key using a DEL command and sets the new key with the retrieved value using a SET command. 
-func decomposeRename(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeRename(ctx context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { // Waiting for GET command response var val string select { case <-ctx.Done(): - slog.Error("IOThread timed out waiting for response from shards", slog.String("id", thread.id), slog.Any("error", ctx.Err())) - case preProcessedResp, ok := <-thread.preprocessingChan: + slog.Error("CommandHandler timed out waiting for response from shards", slog.String("id", h.id), slog.Any("error", ctx.Err())) + case preProcessedResp, ok := <-h.preprocessingChan: if ok { evalResp := preProcessedResp.EvalResponse if evalResp.Error != nil { @@ -85,13 +85,13 @@ func decomposeRename(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCm // decomposeCopy breaks down the COPY command into a SET command that copies a value from // one key to another. It first retrieves the value of the original key from shards, then // sets the value to the destination key using a SET command. -func decomposeCopy(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeCopy(ctx context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { // Waiting for GET command response var resp *ops.StoreResponse select { case <-ctx.Done(): - slog.Error("IOThread timed out waiting for response from shards", slog.String("id", thread.id), slog.Any("error", ctx.Err())) - case preProcessedResp, ok := <-thread.preprocessingChan: + slog.Error("CommandHandler timed out waiting for response from shards", slog.String("id", h.id), slog.Any("error", ctx.Err())) + case preProcessedResp, ok := <-h.preprocessingChan: if ok { resp = preProcessedResp } @@ -124,7 +124,7 @@ func decomposeCopy(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) // decomposeMSet decomposes the MSET (Multi-set) command into individual SET commands. 
// It expects an even number of arguments (key-value pairs). For each pair, it creates // a separate SET command to store the value at the given key. -func decomposeMSet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeMSet(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args)%2 != 0 { return nil, diceerrors.ErrWrongArgumentCount("MSET") } @@ -148,7 +148,7 @@ func decomposeMSet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cm // decomposeMGet decomposes the MGET (Multi-get) command into individual GET commands. // It expects a list of keys, and for each key, it creates a separate GET command to // retrieve the value associated with that key. -func decomposeMGet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeMGet(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) < 1 { return nil, diceerrors.ErrWrongArgumentCount("MGET") } @@ -164,7 +164,7 @@ func decomposeMGet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cm return decomposedCmds, nil } -func decomposeSInter(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeSInter(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) < 1 { return nil, diceerrors.ErrWrongArgumentCount("SINTER") } @@ -180,7 +180,7 @@ func decomposeSInter(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]* return decomposedCmds, nil } -func decomposeSDiff(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeSDiff(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) < 1 { return nil, diceerrors.ErrWrongArgumentCount("SDIFF") } @@ -196,7 +196,7 @@ func decomposeSDiff(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*c 
return decomposedCmds, nil } -func decomposeJSONMget(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeJSONMget(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) < 2 { return nil, diceerrors.ErrWrongArgumentCount("JSON.MGET") } @@ -215,7 +215,7 @@ func decomposeJSONMget(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([ return decomposedCmds, nil } -func decomposeTouch(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeTouch(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) == 0 { return nil, diceerrors.ErrWrongArgumentCount("TOUCH") } @@ -232,13 +232,13 @@ func decomposeTouch(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*c return decomposedCmds, nil } -func decomposeDBSize(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeDBSize(_ context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) > 0 { return nil, diceerrors.ErrWrongArgumentCount("DBSIZE") } decomposedCmds := make([]*cmd.DiceDBCmd, 0, len(cd.Args)) - for i := uint8(0); i < uint8(thread.shardManager.GetShardCount()); i++ { + for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ { decomposedCmds = append(decomposedCmds, &cmd.DiceDBCmd{ Cmd: store.SingleShardSize, @@ -249,13 +249,13 @@ func decomposeDBSize(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) return decomposedCmds, nil } -func decomposeKeys(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeKeys(_ context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) != 1 { return nil, diceerrors.ErrWrongArgumentCount("KEYS") } decomposedCmds := make([]*cmd.DiceDBCmd, 0, len(cd.Args)) - for i := uint8(0); i < 
uint8(thread.shardManager.GetShardCount()); i++ { + for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ { decomposedCmds = append(decomposedCmds, &cmd.DiceDBCmd{ Cmd: store.SingleShardKeys, @@ -266,13 +266,13 @@ func decomposeKeys(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ( return decomposedCmds, nil } -func decomposeFlushDB(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { +func decomposeFlushDB(_ context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) { if len(cd.Args) > 1 { return nil, diceerrors.ErrWrongArgumentCount("FLUSHDB") } decomposedCmds := make([]*cmd.DiceDBCmd, 0, len(cd.Args)) - for i := uint8(0); i < uint8(thread.shardManager.GetShardCount()); i++ { + for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ { decomposedCmds = append(decomposedCmds, &cmd.DiceDBCmd{ Cmd: store.FlushDB, diff --git a/internal/iothread/cmd_meta.go b/internal/commandhandler/cmd_meta.go similarity index 96% rename from internal/iothread/cmd_meta.go rename to internal/commandhandler/cmd_meta.go index 837a5be4b..2dd732b1b 100644 --- a/internal/iothread/cmd_meta.go +++ b/internal/commandhandler/cmd_meta.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package iothread +package commandhandler import ( "context" @@ -214,12 +214,12 @@ const ( type CmdMeta struct { CmdType - Cmd string - IOThreadHandler func([]string) []byte + Cmd string + CmdHandlerFunction func([]string) []byte // decomposeCommand is a function that takes a DiceDB command and breaks it down into smaller, // manageable DiceDB commands for each shard processing. It returns a slice of DiceDB commands. 
- decomposeCommand func(ctx context.Context, thread *BaseIOThread, DiceDBCmd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) + decomposeCommand func(ctx context.Context, h *BaseCommandHandler, DiceDBCmd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) // composeResponse is a function that combines multiple responses from the execution of commands // into a single response object. It accepts a variadic parameter of EvalResponse objects @@ -234,10 +234,10 @@ type CmdMeta struct { // preProcessResponse is a function that handles the preprocessing of a DiceDB command by // preparing the necessary operations (e.g., fetching values from shards) before the command - // is executed. It takes the io-thread and the original DiceDB command as parameters and + // is executed. It takes the CommandHandler and the original DiceDB command as parameters and // ensures that any required information is retrieved and processed in advance. Use this when set // preProcessingReq = true. - preProcessResponse func(thread *BaseIOThread, DiceDBCmd *cmd.DiceDBCmd) error + preProcessResponse func(h *BaseCommandHandler, DiceDBCmd *cmd.DiceDBCmd) error } var CommandsMeta = map[string]CmdMeta{ @@ -695,8 +695,8 @@ func init() { func validateCmdMeta(c string, meta CmdMeta) error { switch meta.CmdType { case Global: - if meta.IOThreadHandler == nil { - return fmt.Errorf("global command %s must have IOThreadHandler function", c) + if meta.CmdHandlerFunction == nil { + return fmt.Errorf("global command %s must have CmdHandlerFunction function", c) } case MultiShard, AllShard: if meta.decomposeCommand == nil || meta.composeResponse == nil { diff --git a/internal/iothread/cmd_preprocess.go b/internal/commandhandler/cmd_preprocess.go similarity index 87% rename from internal/iothread/cmd_preprocess.go rename to internal/commandhandler/cmd_preprocess.go index b5116622f..c985997b7 100644 --- a/internal/iothread/cmd_preprocess.go +++ b/internal/commandhandler/cmd_preprocess.go @@ -14,7 +14,7 @@ // You should have 
received a copy of the GNU Affero General Public License // along with this program. If not, see . -package iothread +package commandhandler import ( "github.com/dicedb/dice/internal/cmd" @@ -25,13 +25,13 @@ import ( // preProcessRename prepares the RENAME command for preprocessing by sending a GET command // to retrieve the value of the original key. The retrieved value is used later in the // decomposeRename function to delete the old key and set the new key. -func preProcessRename(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { +func preProcessRename(h *BaseCommandHandler, diceDBCmd *cmd.DiceDBCmd) error { if len(diceDBCmd.Args) < 2 { return diceerrors.ErrWrongArgumentCount("RENAME") } key := diceDBCmd.Args[0] - sid, rc := thread.shardManager.GetShardInfo(key) + sid, rc := h.shardManager.GetShardInfo(key) preCmd := cmd.DiceDBCmd{ Cmd: "RENAME", @@ -42,7 +42,7 @@ func preProcessRename(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { SeqID: 0, RequestID: GenerateUniqueRequestID(), Cmd: &preCmd, - IOThreadID: thread.id, + CmdHandlerID: h.id, ShardID: sid, Client: nil, PreProcessing: true, @@ -54,12 +54,12 @@ func preProcessRename(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { // preProcessCopy prepares the COPY command for preprocessing by sending a GET command // to retrieve the value of the original key. The retrieved value is used later in the // decomposeCopy function to copy the value to the destination key. 
-func customProcessCopy(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { +func customProcessCopy(h *BaseCommandHandler, diceDBCmd *cmd.DiceDBCmd) error { if len(diceDBCmd.Args) < 2 { return diceerrors.ErrWrongArgumentCount("COPY") } - sid, rc := thread.shardManager.GetShardInfo(diceDBCmd.Args[0]) + sid, rc := h.shardManager.GetShardInfo(diceDBCmd.Args[0]) preCmd := cmd.DiceDBCmd{ Cmd: "COPY", @@ -71,7 +71,7 @@ func customProcessCopy(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error { SeqID: 0, RequestID: GenerateUniqueRequestID(), Cmd: &preCmd, - IOThreadID: thread.id, + CmdHandlerID: h.id, ShardID: sid, Client: nil, PreProcessing: true, diff --git a/internal/commandhandler/commandhandler.go b/internal/commandhandler/commandhandler.go new file mode 100644 index 000000000..d9c719498 --- /dev/null +++ b/internal/commandhandler/commandhandler.go @@ -0,0 +1,519 @@ +package commandhandler + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "strconv" + "sync/atomic" + "syscall" + "time" + + "github.com/dicedb/dice/config" + "github.com/dicedb/dice/internal/auth" + "github.com/dicedb/dice/internal/clientio" + "github.com/dicedb/dice/internal/clientio/requestparser" + "github.com/dicedb/dice/internal/cmd" + diceerrors "github.com/dicedb/dice/internal/errors" + "github.com/dicedb/dice/internal/ops" + "github.com/dicedb/dice/internal/querymanager" + "github.com/dicedb/dice/internal/shard" + "github.com/dicedb/dice/internal/wal" + "github.com/dicedb/dice/internal/watchmanager" + "github.com/google/uuid" +) + +const defaultRequestTimeout = 6 * time.Second + +var requestCounter uint32 + +type CommandHandler interface { + ID() string + Start(ctx context.Context) error + Stop() error +} + +type BaseCommandHandler struct { + CommandHandler + id string + parser requestparser.Parser + shardManager *shard.ShardManager + adhocReqChan chan *cmd.DiceDBCmd + Session *auth.Session + globalErrorChan chan error + ioThreadReadChan chan []byte // Channel to receive 
data from io-thread + ioThreadWriteChan chan interface{} // Channel to send data to io-thread + ioThreadErrChan chan error // Channel to receive errors from io-thread + responseChan chan *ops.StoreResponse // Channel to communicate with shard + preprocessingChan chan *ops.StoreResponse // Channel to communicate with shard + cmdWatchSubscriptionChan chan watchmanager.WatchSubscription + wl wal.AbstractWAL +} + +func NewCommandHandler(id string, responseChan, preprocessingChan chan *ops.StoreResponse, + cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, + parser requestparser.Parser, shardManager *shard.ShardManager, gec chan error, + ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}, ioThreadErrChan chan error, + wl wal.AbstractWAL) *BaseCommandHandler { + return &BaseCommandHandler{ + id: id, + parser: parser, + shardManager: shardManager, + adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize), + Session: auth.NewSession(), + globalErrorChan: gec, + ioThreadReadChan: ioThreadReadChan, + ioThreadWriteChan: ioThreadWriteChan, + ioThreadErrChan: ioThreadErrChan, + responseChan: responseChan, + preprocessingChan: preprocessingChan, + cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, + wl: wl, + } +} + +func (h *BaseCommandHandler) ID() string { + return h.id +} + +func (h *BaseCommandHandler) Start(ctx context.Context) error { + errChan := make(chan error, 1) // for adhoc request processing errors + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-h.ioThreadErrChan: + return err + case cmdReq := <-h.adhocReqChan: + resp, err := h.handleCmdRequestWithTimeout(ctx, errChan, []*cmd.DiceDBCmd{cmdReq}, true, defaultRequestTimeout) + h.sendResponseToIOThread(resp, err) + case err := <-errChan: + return h.handleError(err) + case data := <-h.ioThreadReadChan: + resp, err := h.processCommand(ctx, &data, h.globalErrorChan) + h.sendResponseToIOThread(resp, err) + } + } +} + +// 
processCommand processes commands recevied from io thread +func (h *BaseCommandHandler) processCommand(ctx context.Context, data *[]byte, gec chan error) (interface{}, error) { + commands, err := h.parser.Parse(*data) + + if err != nil { + slog.Debug("error parsing commands from io thread", slog.String("id", h.id), slog.Any("error", err)) + return nil, err + } + + if len(commands) == 0 { + slog.Debug("invalid request from io thread with zero length", slog.String("id", h.id)) + return nil, fmt.Errorf("ERR: Invalid request") + } + + // DiceDB supports clients to send only one request at a time + // We also need to ensure that the client is blocked until the response is received + if len(commands) > 1 { + return nil, fmt.Errorf("ERR: Multiple commands not supported") + } + + err = h.isAuthenticated(commands[0]) + if err != nil { + slog.Debug("command handler authentication failed", slog.String("id", h.id), slog.Any("error", err)) + return nil, err + } + + return h.handleCmdRequestWithTimeout(ctx, gec, commands, false, defaultRequestTimeout) +} + +func (h *BaseCommandHandler) handleCmdRequestWithTimeout(ctx context.Context, gec chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) (interface{}, error) { + execCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return h.executeCommandHandler(execCtx, gec, commands, isWatchNotification) +} + +func (h *BaseCommandHandler) executeCommandHandler(execCtx context.Context, gec chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) (interface{}, error) { + // Retrieve metadata for the command to determine if multisharding is supported. 
+ meta, ok := CommandsMeta[commands[0].Cmd] + if ok && meta.preProcessing { + if err := meta.preProcessResponse(h, commands[0]); err != nil { + slog.Debug("error pre processing response", slog.String("id", h.id), slog.Any("error", err)) + return nil, err + } + } + + resp, err := h.executeCommand(execCtx, commands[0], isWatchNotification) + + // log error and send to global error channel if it's a connection error + if err != nil { + slog.Error("Error executing command", slog.String("id", h.id), slog.Any("error", err)) + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { + slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err)) + gec <- err + } + } + + return resp, err +} + +func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) (interface{}, error) { + // Break down the single command into multiple commands if multisharding is supported. + // The length of cmdList helps determine how many shards to wait for responses. + cmdList := make([]*cmd.DiceDBCmd, 0) + var watchLabel string + + // Retrieve metadata for the command to determine if multisharding is supported. + meta, ok := CommandsMeta[diceDBCmd.Cmd] + if !ok { + // If no metadata exists, treat it as a single command and not migrated + cmdList = append(cmdList, diceDBCmd) + } else { + // Depending on the command type, decide how to handle it. + switch meta.CmdType { + case Global: + // process global command immediately without involving any shards. + return meta.CmdHandlerFunction(diceDBCmd.Args), nil + + case SingleShard: + // For single-shard or custom commands, process them without breaking up. + cmdList = append(cmdList, diceDBCmd) + + case MultiShard, AllShard: + var err error + // If the command supports multisharding, break it down into multiple commands. 
+ cmdList, err = meta.decomposeCommand(ctx, h, diceDBCmd) + if err != nil { + slog.Debug("error decomposing command", slog.String("id", h.id), slog.Any("error", err)) + // Check if it's a CustomError + var customErr *diceerrors.PreProcessError + if errors.As(err, &customErr) { + return nil, err + } + return nil, err + } + + case Custom: + return h.handleCustomCommands(diceDBCmd) + + case Watch: + // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass + // it along as is. + // Modify the command name to remove the .WATCH suffix, this will allow us to generate a consistent + // fingerprint (which uses the command name without the suffix) + diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-6] + + // check if the last argument is a watch label + label := diceDBCmd.Args[len(diceDBCmd.Args)-1] + if _, err := uuid.Parse(label); err == nil { + watchLabel = label + + // remove the watch label from the args + diceDBCmd.Args = diceDBCmd.Args[:len(diceDBCmd.Args)-1] + } + + watchCmd := &cmd.DiceDBCmd{ + Cmd: diceDBCmd.Cmd, + Args: diceDBCmd.Args, + } + cmdList = append(cmdList, watchCmd) + isWatchNotification = true + + case Unwatch: + // Generate the Cmd being unwatched. All we need to do is remove the .UNWATCH suffix from the command and pass + // it along as is. + // Modify the command name to remove the .UNWATCH suffix, this will allow us to generate a consistent + // fingerprint (which uses the command name without the suffix) + diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-8] + watchCmd := &cmd.DiceDBCmd{ + Cmd: diceDBCmd.Cmd, + Args: diceDBCmd.Args, + } + cmdList = append(cmdList, watchCmd) + isWatchNotification = false + } + } + + // Unsubscribe Unwatch command type + if meta.CmdType == Unwatch { + return h.handleCommandUnwatch(cmdList) + } + + // Scatter the broken-down commands to the appropriate shards. 
+ if err := h.scatter(ctx, cmdList, meta.CmdType); err != nil { + return nil, err + } + + // Gather the responses from the shards and write them to the buffer. + resp, err := h.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel) + if err != nil { + return nil, err + } + + // Proceed to subscribe after successful execution + if meta.CmdType == Watch { + h.handleCommandWatch(cmdList) + } + + return resp, nil +} + +func (h *BaseCommandHandler) handleCustomCommands(diceDBCmd *cmd.DiceDBCmd) (interface{}, error) { + // if command is of type Custom, write a custom logic around it + switch diceDBCmd.Cmd { + case CmdAuth: + return h.RespAuth(diceDBCmd.Args), nil + case CmdEcho: + return RespEcho(diceDBCmd.Args), nil + case CmdAbort: + slog.Info("Received ABORT command, initiating server shutdown", slog.String("id", h.id)) + h.globalErrorChan <- diceerrors.ErrAborted + return clientio.OK, nil + case CmdPing: + return RespPING(diceDBCmd.Args), nil + case CmdHello: + return RespHello(diceDBCmd.Args), nil + case CmdSleep: + return RespSleep(diceDBCmd.Args), nil + default: + return nil, diceerrors.ErrUnknownCmd(diceDBCmd.Cmd) + } +} + +// handleCommandWatch sends a watch subscription request to the watch manager. +func (h *BaseCommandHandler) handleCommandWatch(cmdList []*cmd.DiceDBCmd) { + h.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ + Subscribe: true, + WatchCmd: cmdList[len(cmdList)-1], + AdhocReqChan: h.adhocReqChan, + } +} + +// handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client. +// The response is sent before the unwatch request is processed by the watch manager. 
+func (h *BaseCommandHandler) handleCommandUnwatch(cmdList []*cmd.DiceDBCmd) (interface{}, error) { + // extract the fingerprint + command := cmdList[len(cmdList)-1] + fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32) + if parseErr != nil { + return nil, diceerrors.ErrInvalidFingerprint + } + + // send the unsubscribe request + h.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ + Subscribe: false, + AdhocReqChan: h.adhocReqChan, + Fingerprint: uint32(fp), + } + + return clientio.OK, nil +} + +// scatter distributes the DiceDB commands to the respective shards based on the key. +// For each command, it calculates the shard ID and sends the command to the shard's request channel for processing. +func (h *BaseCommandHandler) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd, cmdType CmdType) error { + // Otherwise check for the shard based on the key using hash + // and send it to the particular shard + // Check if the context has been canceled or expired. + select { + case <-ctx.Done(): + // If the context is canceled, return the error associated with it. + return ctx.Err() + default: + // Proceed with the default case when the context is not canceled. + + if cmdType == AllShard { + // If the command type is for all shards, iterate over all available shards. + for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ { + // Get the shard ID (i) and its associated request channel. + shardID, responseChan := i, h.shardManager.GetShard(i).ReqChan + + // Send a StoreOp operation to the shard's request channel. + responseChan <- &ops.StoreOp{ + SeqID: i, // Sequence ID for this operation. + RequestID: GenerateUniqueRequestID(), // Unique identifier for the request. + Cmd: cmds[0], // Command to be executed, using the first command in cmds. + CmdHandlerID: h.id, // ID of the current command handler. + ShardID: shardID, // ID of the shard handling this operation. + Client: nil, // Client information (if applicable). 
+ } + } + } else { + // If the command type is specific to certain commands, process them individually. + for i := uint8(0); i < uint8(len(cmds)); i++ { + // Determine the appropriate shard for the current command using a routing key. + shardID, responseChan := h.shardManager.GetShardInfo(getRoutingKeyFromCommand(cmds[i])) + + // Send a StoreOp operation to the shard's request channel. + responseChan <- &ops.StoreOp{ + SeqID: i, // Sequence ID for this operation. + RequestID: GenerateUniqueRequestID(), // Unique identifier for the request. + Cmd: cmds[i], // Command to be executed, using the current command in cmds. + CmdHandlerID: h.id, // ID of the current command handler. + ShardID: shardID, // ID of the shard handling this operation. + Client: nil, // Client information (if applicable). + } + } + } + } + + return nil +} + +// getRoutingKeyFromCommand determines the key used for shard routing +func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string { + if len(diceDBCmd.Args) > 0 { + return diceDBCmd.Args[0] + } + return diceDBCmd.Cmd +} + +// gather collects the responses from multiple shards and writes the results into the provided buffer. +// It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). 
+func (h *BaseCommandHandler) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) (interface{}, error) {
+	// Collect responses from all shards
+	storeOp, err := h.gatherResponses(ctx, numCmds)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(storeOp) == 0 {
+		slog.Error("No response from shards",
+			slog.String("id", h.id),
+			slog.String("command", diceDBCmd.Cmd))
+		return nil, fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd)
+	}
+
+	if isWatchNotification {
+		return h.handleWatchNotification(diceDBCmd, storeOp[0], watchLabel)
+	}
+
+	// Process command based on its type
+	cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd]
+	if !ok {
+		return h.handleUnsupportedCommand(storeOp[0])
+	}
+
+	return h.handleCommand(cmdMeta, diceDBCmd, storeOp)
+}
+
+// gatherResponses waits until numCmds responses have arrived on the handler's
+// response channel and returns them in arrival order. It returns ctx.Err() when
+// the context is done first, or the shard's error when one is reported on the
+// shard manager's error channel.
+func (h *BaseCommandHandler) gatherResponses(ctx context.Context, numCmds int) ([]ops.StoreResponse, error) {
+	storeOp := make([]ops.StoreResponse, 0, numCmds)
+	// Local copy so this case can be switched off if the channel gets closed.
+	shardErrChan := h.shardManager.ShardErrorChan
+
+	for numCmds > 0 {
+		select {
+		case <-ctx.Done():
+			slog.Error("Timed out waiting for response from shards",
+				slog.String("id", h.id),
+				slog.Any("error", ctx.Err()))
+			return nil, ctx.Err()
+
+		case resp, ok := <-h.responseChan:
+			if ok {
+				storeOp = append(storeOp, *resp)
+			}
+			// A receive from a closed channel still counts, so the loop terminates.
+			numCmds--
+
+		case sError, ok := <-shardErrChan:
+			if !ok {
+				// A closed channel is always ready and would make this select spin.
+				// Receiving from a nil channel blocks forever, which disables the case.
+				shardErrChan = nil
+				continue
+			}
+			slog.Error("Error from shard",
+				slog.String("id", h.id),
+				slog.Any("error", sError))
+			return nil, sError.Error
+		}
+	}
+
+	return storeOp, nil
+}
+
+// handleWatchNotification builds the watch response for diceDBCmd from a single
+// shard response. An evaluation error is not returned as a Go error; it is
+// embedded in the watch response instead.
+func (h *BaseCommandHandler) handleWatchNotification(diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) (interface{}, error) {
+	fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint())
+
+	// if watch label is not empty, then this is the first response for the watch command
+	// hence, we will send the watch label as part of the response
+	firstRespElem := diceDBCmd.Cmd
+	if watchLabel != "" {
+		firstRespElem = watchLabel
+	}
+
+	if resp.EvalResponse.Error != nil {
+		// This is a special case where error is returned as part of the watch response
+		return querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error), nil
+	}
+
+	return querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result), nil
+}
+
+// handleUnsupportedCommand processes commands not in CommandsMeta: it returns the
+// evaluation error if there is one, otherwise the evaluation result.
+func (h *BaseCommandHandler) handleUnsupportedCommand(resp ops.StoreResponse) (interface{}, error) {
+	if resp.EvalResponse.Error != nil {
+		return nil, resp.EvalResponse.Error
+	}
+	return resp.EvalResponse.Result, nil
+}
+
+// handleCommand turns the gathered shard responses into the command's result.
+// SingleShard and Custom commands use the first response only; MultiShard and
+// AllShard commands are combined with cmdMeta.composeResponse. Any other command
+// type is logged and reported as diceerrors.ErrInternalServer.
+func (h *BaseCommandHandler) handleCommand(cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) (interface{}, error) {
+	switch cmdMeta.CmdType {
+	case SingleShard, Custom:
+		if storeOp[0].EvalResponse.Error != nil {
+			return nil, storeOp[0].EvalResponse.Error
+		}
+		return storeOp[0].EvalResponse.Result, nil
+
+	case MultiShard, AllShard:
+		return cmdMeta.composeResponse(storeOp...), nil
+
+	default:
+		slog.Error("Unknown command type",
+			slog.String("id", h.id),
+			slog.String("command", diceDBCmd.Cmd),
+			slog.Any("evalResp", storeOp))
+		return nil, diceerrors.ErrInternalServer
+	}
+}
+
+// handleError classifies an error seen while responding to the client.
+// A nil err yields nil. Connection-closed errors (net.ErrClosed, EPIPE,
+// ECONNRESET) are logged at debug level and returned unchanged. Every other
+// error is wrapped, so callers can still match it with errors.Is / errors.As.
+func (h *BaseCommandHandler) handleError(err error) error {
+	if err == nil {
+		// Nothing went wrong; do not manufacture an "error writing response: <nil>".
+		return nil
+	}
+	if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) {
+		slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err))
+		return err
+	}
+	return fmt.Errorf("error writing response: %w", err)
+}
+
+// sendResponseToIOThread hands the outcome of a command to the io-thread's write
+// channel. Exactly one value is sent per call: resp when err is nil, the Result
+// of a diceerrors.PreProcessError when err is (or wraps) one, and err otherwise.
+func (h *BaseCommandHandler) sendResponseToIOThread(resp interface{}, err error) {
+	if err == nil {
+		h.ioThreadWriteChan <- resp
+		return
+	}
+
+	var customErr *diceerrors.PreProcessError
+	if errors.As(err, &customErr) {
+		h.ioThreadWriteChan <- customErr.Result
+		// Stop here: sending err as well would give the client a second reply.
+		return
+	}
+	h.ioThreadWriteChan <- err
+}
+
+// isAuthenticated returns an error unless the session is active or the command
+// is the auth command itself.
+func (h *BaseCommandHandler) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error {
+	if diceDBCmd.Cmd != auth.Cmd && !h.Session.IsActive() {
+		return errors.New("NOAUTH Authentication required")
+	}
+
+	return nil
+}
+
+// Stop expires the handler's session. It always returns nil.
+func (h *BaseCommandHandler) Stop() error {
+	slog.Info("Stopping command handler", slog.String("id", h.id))
+	h.Session.Expire()
+	return nil
+}
+
+// GenerateUniqueRequestID atomically increments the package-level request
+// counter and returns the new value.
+func GenerateUniqueRequestID() uint32 {
+	return atomic.AddUint32(&requestCounter, 1)
+}
diff --git a/internal/commandhandler/registry.go b/internal/commandhandler/registry.go
new file mode 100644
index 000000000..1ce93eb31
--- /dev/null
+++ b/internal/commandhandler/registry.go
@@ -0,0 +1,88 @@
+package commandhandler
+
+import (
+	"errors"
+	"sync"
+	"sync/atomic"
+
+	"github.com/dicedb/dice/internal/shard"
+)
+
+// Registry keeps track of the active command handlers, caps how many may be
+// registered at once and wires their channels into the shard manager.
+type Registry struct {
+	activeCmdHandlers sync.Map
+	numCmdHandlers    atomic.Int32
+	maxClients        int32
+	ShardManager      *shard.ShardManager
+	mu                sync.Mutex
+}
+
+// Errors returned by the Registry.
+var (
+	ErrMaxCmdHandlersReached = errors.New("maximum number of command handlers reached")
+	ErrCmdHandlerNotFound    = errors.New("command handler not found")
+	ErrCmdHandlerNotBase     = errors.New("command handler is not a BaseCommandHandler")
+)
+
+// NewRegistry returns a Registry that accepts at most maxClients command
+// handlers and registers them with sm.
+func NewRegistry(maxClients int32, sm *shard.ShardManager) *Registry {
+	return &Registry{
+		maxClients:   maxClients,
+		ShardManager: sm,
+	}
+}
+
+// RegisterCommandHandler adds cmdHandler to the registry. It returns
+// ErrMaxCmdHandlersReached when the registry already holds maxClients handlers.
+// A handler that has a response channel is also registered with the shard manager.
+func (m *Registry) RegisterCommandHandler(cmdHandler *BaseCommandHandler) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if m.CommandHandlerCount() >= m.maxClients {
+		return ErrMaxCmdHandlersReached
+	}
+
+	responseChan := cmdHandler.responseChan
+	preprocessingChan := cmdHandler.preprocessingChan
+
+	if responseChan != nil && preprocessingChan != nil {
+		m.ShardManager.RegisterCommandHandler(cmdHandler.ID(), responseChan, preprocessingChan) // TODO: Change responseChan type to ShardResponse
+	} else if responseChan != nil && preprocessingChan == nil {
+		m.ShardManager.RegisterCommandHandler(cmdHandler.ID(), responseChan, nil)
+	}
+
+	m.activeCmdHandlers.Store(cmdHandler.ID(), cmdHandler)
+
+	m.numCmdHandlers.Add(1)
+	return nil
+}
+
+// CommandHandlerCount returns the number of currently registered command handlers.
+func (m *Registry) CommandHandlerCount() int32 {
+	return m.numCmdHandlers.Load()
+}
+
+// UnregisterCommandHandler removes the handler with the given id from the shard
+// manager and the registry, then stops it. It returns ErrCmdHandlerNotFound when
+// no such handler is registered, ErrCmdHandlerNotBase when the stored value is
+// not a *BaseCommandHandler, or the error returned by Stop.
+func (m *Registry) UnregisterCommandHandler(id string) error {
+	m.ShardManager.UnregisterCommandHandler(id)
+
+	cmdHandler, loaded := m.activeCmdHandlers.LoadAndDelete(id)
+	if !loaded {
+		return ErrCmdHandlerNotFound
+	}
+	// LoadAndDelete has already removed the entry, so release its slot now;
+	// the early returns below must not leave the counter too high.
+	m.numCmdHandlers.Add(-1)
+
+	ch, ok := cmdHandler.(*BaseCommandHandler)
+	if !ok {
+		return ErrCmdHandlerNotBase
+	}
+	return ch.Stop()
+}
diff --git a/internal/errors/errors.go b/internal/errors/errors.go
index 3d5b28e0a..04d7b1e3c 100644
--- a/internal/errors/errors.go
+++ b/internal/errors/errors.go
@@ -35,7 +35,7 @@ const (
 	WrongTypeErr = "-WRONGTYPE Operation against a key holding the wrong kind of value"
 	WrongTypeHllErr = "-WRONGTYPE Key is not a valid HyperLogLog string value."
InvalidHllErr = "-INVALIDOBJ Corrupted HLL object detected" - IOThreadNotFoundErr = "io-thread with ID %s not found" + CmdHandlerNotFoundErr = "command handler with ID %s not found" JSONPathValueTypeErr = "-WRONGTYPE wrong type of path value - expected string but found %s" HashValueNotIntegerErr = "hash value is not an integer" InternalServerError = "-ERR: Internal server error, unable to process command" diff --git a/internal/iothread/iothread.go b/internal/iothread/iothread.go index b7ef96ee6..d7c108235 100644 --- a/internal/iothread/iothread.go +++ b/internal/iothread/iothread.go @@ -18,35 +18,12 @@ package iothread import ( "context" - "errors" - "fmt" "log/slog" - "net" - "strconv" - "strings" - "sync/atomic" - "syscall" - "time" - "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/auth" - "github.com/dicedb/dice/internal/clientio" "github.com/dicedb/dice/internal/clientio/iohandler" - "github.com/dicedb/dice/internal/clientio/requestparser" - "github.com/dicedb/dice/internal/cmd" - diceerrors "github.com/dicedb/dice/internal/errors" - "github.com/dicedb/dice/internal/ops" - "github.com/dicedb/dice/internal/querymanager" - "github.com/dicedb/dice/internal/shard" - "github.com/dicedb/dice/internal/wal" - "github.com/dicedb/dice/internal/watchmanager" - "github.com/google/uuid" ) -const defaultRequestTimeout = 6 * time.Second - -var requestCounter uint32 - // IOThread interface type IOThread interface { ID() string @@ -56,35 +33,23 @@ type IOThread interface { type BaseIOThread struct { IOThread - id string - ioHandler iohandler.IOHandler - parser requestparser.Parser - shardManager *shard.ShardManager - adhocReqChan chan *cmd.DiceDBCmd - Session *auth.Session - globalErrorChan chan error - responseChan chan *ops.StoreResponse - preprocessingChan chan *ops.StoreResponse - cmdWatchSubscriptionChan chan watchmanager.WatchSubscription - wl wal.AbstractWAL + id string + ioHandler iohandler.IOHandler + Session *auth.Session + ioThreadReadChan chan 
[]byte // Channel to send data to the command handler + ioThreadWriteChan chan interface{} // Channel to receive data from the command handler + ioThreadErrChan chan error // Channel to receive errors from the ioHandler } -func NewIOThread(wid string, responseChan, preprocessingChan chan *ops.StoreResponse, - cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, - ioHandler iohandler.IOHandler, parser requestparser.Parser, - shardManager *shard.ShardManager, gec chan error, wl wal.AbstractWAL) *BaseIOThread { +func NewIOThread(id string, ioHandler iohandler.IOHandler, + ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}, ioThreadErrChan chan error) *BaseIOThread { return &BaseIOThread{ - id: wid, - ioHandler: ioHandler, - parser: parser, - shardManager: shardManager, - globalErrorChan: gec, - responseChan: responseChan, - preprocessingChan: preprocessingChan, - Session: auth.NewSession(), - adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize), - cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, - wl: wl, + id: id, + ioHandler: ioHandler, + Session: auth.NewSession(), + ioThreadReadChan: ioThreadReadChan, + ioThreadWriteChan: ioThreadWriteChan, + ioThreadErrChan: ioThreadErrChan, } } @@ -93,9 +58,9 @@ func (t *BaseIOThread) ID() string { } func (t *BaseIOThread) Start(ctx context.Context) error { - errChan := make(chan error, 1) - incomingDataChan := make(chan []byte) - readErrChan := make(chan error) + // local channels to communicate between Start and startInputReader goroutine + incomingDataChan := make(chan []byte) // data channel + readErrChan := make(chan error) // error channel runCtx, runCancel := context.WithCancel(ctx) defer runCancel() @@ -111,17 +76,17 @@ func (t *BaseIOThread) Start(ctx context.Context) error { slog.Warn("Error stopping io-thread:", slog.String("id", t.id), slog.Any("error", err)) } return ctx.Err() - case err := <-errChan: - return t.handleError(err) - case cmdReq := 
<-t.adhocReqChan: - t.handleCmdRequestWithTimeout(ctx, errChan, []*cmd.DiceDBCmd{cmdReq}, true, defaultRequestTimeout) case data := <-incomingDataChan: - if err := t.processIncomingData(ctx, &data, errChan); err != nil { - return err - } + t.ioThreadReadChan <- data case err := <-readErrChan: slog.Debug("Read error in io-thread, connection closed possibly", slog.String("id", t.id), slog.Any("error", err)) + t.ioThreadErrChan <- err return err + case resp := <-t.ioThreadWriteChan: + err := t.ioHandler.Write(ctx, resp) + if err != nil { + slog.Debug("Error sending response to client", slog.String("id", t.id), slog.Any("error", err)) + } } } } @@ -149,484 +114,8 @@ func (t *BaseIOThread) startInputReader(ctx context.Context, incomingDataChan ch } } -func (t *BaseIOThread) handleError(err error) error { - if err != nil { - if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { - slog.Debug("Connection closed for io-thread", slog.String("id", t.id), slog.Any("error", err)) - return err - } - } - return fmt.Errorf("error writing response: %v", err) -} - -func (t *BaseIOThread) processIncomingData(ctx context.Context, data *[]byte, errChan chan error) error { - commands, err := t.parser.Parse(*data) - - if err != nil { - err = t.ioHandler.Write(ctx, err) - if err != nil { - slog.Debug("Write error, connection closed possibly", slog.String("id", t.id), slog.Any("error", err)) - return err - } - return nil - } - - if len(commands) == 0 { - err = t.ioHandler.Write(ctx, fmt.Errorf("ERR: Invalid request")) - if err != nil { - slog.Debug("Write error, connection closed possibly", slog.String("id", t.id), slog.Any("error", err)) - return err - } - return nil - } - - // DiceDB supports clients to send only one request at a time - // We also need to ensure that the client is blocked until the response is received - if len(commands) > 1 { - err = t.ioHandler.Write(ctx, fmt.Errorf("ERR: Multiple commands not supported")) - if err != 
nil { - slog.Debug("Write error, connection closed possibly", slog.String("id", t.id), slog.Any("error", err)) - return err - } - } - - err = t.isAuthenticated(commands[0]) - if err != nil { - writeErr := t.ioHandler.Write(ctx, err) - if writeErr != nil { - slog.Debug("Write error, connection closed possibly", slog.Any("error", errors.Join(err, writeErr))) - return errors.Join(err, writeErr) - } - return nil - } - - t.handleCmdRequestWithTimeout(ctx, errChan, commands, false, defaultRequestTimeout) - return nil -} - -func (t *BaseIOThread) handleCmdRequestWithTimeout(ctx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) { - execCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - t.executeCommandHandler(execCtx, errChan, commands, isWatchNotification) -} - -func (t *BaseIOThread) executeCommandHandler(execCtx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) { - // Retrieve metadata for the command to determine if multisharding is supported. 
- meta, ok := CommandsMeta[commands[0].Cmd] - if ok && meta.preProcessing { - if err := meta.preProcessResponse(t, commands[0]); err != nil { - e := t.ioHandler.Write(execCtx, err) - if e != nil { - slog.Debug("Error executing for io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - } - } - - err := t.executeCommand(execCtx, commands[0], isWatchNotification) - if err != nil { - slog.Error("Error executing command", slog.String("id", t.id), slog.Any("error", err)) - if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { - slog.Debug("Connection closed for io-thread", slog.String("id", t.id), slog.Any("error", err)) - errChan <- err - } - } -} - -func (t *BaseIOThread) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) error { - // Break down the single command into multiple commands if multisharding is supported. - // The length of cmdList helps determine how many shards to wait for responses. - cmdList := make([]*cmd.DiceDBCmd, 0) - var watchLabel string - - // Retrieve metadata for the command to determine if multisharding is supported. - meta, ok := CommandsMeta[diceDBCmd.Cmd] - if !ok { - // If no metadata exists, treat it as a single command and not migrated - cmdList = append(cmdList, diceDBCmd) - } else { - // Depending on the command type, decide how to handle it. - switch meta.CmdType { - case Global: - // If it's a global command, process it immediately without involving any shards. - err := t.ioHandler.Write(ctx, meta.IOThreadHandler(diceDBCmd.Args)) - slog.Debug("Error executing command on io-thread", slog.String("id", t.id), slog.Any("error", err)) - return err - - case SingleShard: - // For single-shard or custom commands, process them without breaking up. 
- cmdList = append(cmdList, diceDBCmd) - - case MultiShard, AllShard: - var err error - // If the command supports multisharding, break it down into multiple commands. - cmdList, err = meta.decomposeCommand(ctx, t, diceDBCmd) - if err != nil { - var ioErr error - // Check if it's a CustomError - var customErr *diceerrors.PreProcessError - if errors.As(err, &customErr) { - ioErr = t.ioHandler.Write(ctx, customErr.Result) - } else { - ioErr = t.ioHandler.Write(ctx, err) - } - if ioErr != nil { - slog.Debug("Error executing for io-thread", slog.String("id", t.id), slog.Any("error", ioErr)) - } - return ioErr - } - - case Custom: - return t.handleCustomCommands(ctx, diceDBCmd) - - case Watch: - // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass - // it along as is. - // Modify the command name to remove the .WATCH suffix, this will allow us to generate a consistent - // fingerprint (which uses the command name without the suffix) - diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-6] - - // check if the last argument is a watch label - label := diceDBCmd.Args[len(diceDBCmd.Args)-1] - if _, err := uuid.Parse(label); err == nil { - watchLabel = label - - // remove the watch label from the args - diceDBCmd.Args = diceDBCmd.Args[:len(diceDBCmd.Args)-1] - } - - watchCmd := &cmd.DiceDBCmd{ - Cmd: diceDBCmd.Cmd, - Args: diceDBCmd.Args, - } - cmdList = append(cmdList, watchCmd) - isWatchNotification = true - - case Unwatch: - // Generate the Cmd being unwatched. All we need to do is remove the .UNWATCH suffix from the command and pass - // it along as is. 
- // Modify the command name to remove the .UNWATCH suffix, this will allow us to generate a consistent - // fingerprint (which uses the command name without the suffix) - diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-8] - watchCmd := &cmd.DiceDBCmd{ - Cmd: diceDBCmd.Cmd, - Args: diceDBCmd.Args, - } - cmdList = append(cmdList, watchCmd) - isWatchNotification = false - } - } - - // Unsubscribe Unwatch command type - if meta.CmdType == Unwatch { - return t.handleCommandUnwatch(ctx, cmdList) - } - - // Scatter the broken-down commands to the appropriate shards. - if err := t.scatter(ctx, cmdList, meta.CmdType); err != nil { - return err - } - - // Gather the responses from the shards and write them to the buffer. - if err := t.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel); err != nil { - return err - } - - if meta.CmdType == Watch { - // Proceed to subscribe after successful execution - t.handleCommandWatch(cmdList) - } - - return nil -} - -func (t *BaseIOThread) handleCustomCommands(ctx context.Context, diceDBCmd *cmd.DiceDBCmd) error { - // if command is of type Custom, write a custom logic around it - switch diceDBCmd.Cmd { - case CmdAuth: - err := t.ioHandler.Write(ctx, t.RespAuth(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending auth response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - case CmdEcho: - err := t.ioHandler.Write(ctx, RespEcho(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending echo response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - case CmdAbort: - err := t.ioHandler.Write(ctx, clientio.OK) - if err != nil { - slog.Error("Error sending abort response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - slog.Info("Received ABORT command, initiating server shutdown", slog.String("id", t.id)) - t.globalErrorChan <- diceerrors.ErrAborted - return err - case CmdPing: - err := t.ioHandler.Write(ctx, 
RespPING(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending ping response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - case CmdHello: - err := t.ioHandler.Write(ctx, RespHello(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending ping response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - case CmdSleep: - err := t.ioHandler.Write(ctx, RespSleep(diceDBCmd.Args)) - if err != nil { - slog.Error("Error sending ping response to io-thread", slog.String("id", t.id), slog.Any("error", err)) - } - return err - default: - return diceerrors.ErrUnknownCmd(diceDBCmd.Cmd) - } -} - -// handleCommandWatch sends a watch subscription request to the watch manager. -func (t *BaseIOThread) handleCommandWatch(cmdList []*cmd.DiceDBCmd) { - t.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ - Subscribe: true, - WatchCmd: cmdList[len(cmdList)-1], - AdhocReqChan: t.adhocReqChan, - } -} - -// handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client. -// The response is sent before the unwatch request is processed by the watch manager. 
-func (t *BaseIOThread) handleCommandUnwatch(ctx context.Context, cmdList []*cmd.DiceDBCmd) error { - // extract the fingerprint - command := cmdList[len(cmdList)-1] - fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32) - if parseErr != nil { - err := t.ioHandler.Write(ctx, diceerrors.ErrInvalidFingerprint) - if err != nil { - return fmt.Errorf("error sending push response to client: %v", err) - } - return parseErr - } - - // send the unsubscribe request - t.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ - Subscribe: false, - AdhocReqChan: t.adhocReqChan, - Fingerprint: uint32(fp), - } - - err := t.ioHandler.Write(ctx, clientio.RespOK) - if err != nil { - return fmt.Errorf("error sending push response to client: %v", err) - } - return nil -} - -// scatter distributes the DiceDB commands to the respective shards based on the key. -// For each command, it calculates the shard ID and sends the command to the shard's request channel for processing. -func (t *BaseIOThread) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd, cmdType CmdType) error { - // Otherwise check for the shard based on the key using hash - // and send it to the particular shard - // Check if the context has been canceled or expired. - select { - case <-ctx.Done(): - // If the context is canceled, return the error associated with it. - return ctx.Err() - default: - // Proceed with the default case when the context is not canceled. - - if cmdType == AllShard { - // If the command type is for all shards, iterate over all available shards. - for i := uint8(0); i < uint8(t.shardManager.GetShardCount()); i++ { - // Get the shard ID (i) and its associated request channel. - shardID, responseChan := i, t.shardManager.GetShard(i).ReqChan - - // Send a StoreOp operation to the shard's request channel. - responseChan <- &ops.StoreOp{ - SeqID: i, // Sequence ID for this operation. - RequestID: GenerateUniqueRequestID(), // Unique identifier for the request. 
- Cmd: cmds[0], // Command to be executed, using the first command in cmds. - IOThreadID: t.id, // ID of the current io-thread. - ShardID: shardID, // ID of the shard handling this operation. - Client: nil, // Client information (if applicable). - } - } - } else { - // If the command type is specific to certain commands, process them individually. - for i := uint8(0); i < uint8(len(cmds)); i++ { - // Determine the appropriate shard for the current command using a routing key. - shardID, responseChan := t.shardManager.GetShardInfo(getRoutingKeyFromCommand(cmds[i])) - - // Send a StoreOp operation to the shard's request channel. - responseChan <- &ops.StoreOp{ - SeqID: i, // Sequence ID for this operation. - RequestID: GenerateUniqueRequestID(), // Unique identifier for the request. - Cmd: cmds[i], // Command to be executed, using the current command in cmds. - IOThreadID: t.id, // ID of the current io-thread. - ShardID: shardID, // ID of the shard handling this operation. - Client: nil, // Client information (if applicable). - } - } - } - } - - return nil -} - -// getRoutingKeyFromCommand determines the key used for shard routing -func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string { - if len(diceDBCmd.Args) > 0 { - return diceDBCmd.Args[0] - } - return diceDBCmd.Cmd -} - -// gather collects the responses from multiple shards and writes the results into the provided buffer. -// It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). 
-func (t *BaseIOThread) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) error { - // Collect responses from all shards - storeOp, err := t.gatherResponses(ctx, numCmds) - if err != nil { - return err - } - - if len(storeOp) == 0 { - slog.Error("No response from shards", - slog.String("id", t.id), - slog.String("command", diceDBCmd.Cmd)) - return fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd) - } - - if isWatchNotification { - return t.handleWatchNotification(ctx, diceDBCmd, storeOp[0], watchLabel) - } - - // Process command based on its type - cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd] - if !ok { - return t.handleUnsupportedCommand(ctx, storeOp[0]) - } - - return t.handleCommand(ctx, cmdMeta, diceDBCmd, storeOp) -} - -// gatherResponses collects responses from all shards -func (t *BaseIOThread) gatherResponses(ctx context.Context, numCmds int) ([]ops.StoreResponse, error) { - storeOp := make([]ops.StoreResponse, 0, numCmds) - - for numCmds > 0 { - select { - case <-ctx.Done(): - slog.Error("Timed out waiting for response from shards", - slog.String("id", t.id), - slog.Any("error", ctx.Err())) - return nil, ctx.Err() - - case resp, ok := <-t.responseChan: - if ok { - storeOp = append(storeOp, *resp) - } - numCmds-- - - case sError, ok := <-t.shardManager.ShardErrorChan: - if ok { - slog.Error("Error from shard", - slog.String("id", t.id), - slog.Any("error", sError)) - return nil, sError.Error - } - } - } - - return storeOp, nil -} - -// handleWatchNotification processes watch notification responses -func (t *BaseIOThread) handleWatchNotification(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) error { - fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint()) - - // if watch label is not empty, then this is the first response for the watch command - // hence, we will send the watch label as part of the response - firstRespElem := 
diceDBCmd.Cmd - if watchLabel != "" { - firstRespElem = watchLabel - } - - if resp.EvalResponse.Error != nil { - return t.writeResponse(ctx, querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error)) - } - - return t.writeResponse(ctx, querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result)) -} - -// handleUnsupportedCommand processes commands not in CommandsMeta -func (t *BaseIOThread) handleUnsupportedCommand(ctx context.Context, resp ops.StoreResponse) error { - if resp.EvalResponse.Error != nil { - return t.writeResponse(ctx, resp.EvalResponse.Error) - } - return t.writeResponse(ctx, resp.EvalResponse.Result) -} - -// handleCommand processes commands based on their type -func (t *BaseIOThread) handleCommand(ctx context.Context, cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) error { - var err error - - switch cmdMeta.CmdType { - case SingleShard, Custom: - if storeOp[0].EvalResponse.Error != nil { - err = t.writeResponse(ctx, storeOp[0].EvalResponse.Error) - } else { - err = t.writeResponse(ctx, storeOp[0].EvalResponse.Result) - } - - if err == nil && t.wl != nil { - if err := t.wl.LogCommand([]byte(fmt.Sprintf("%s %s", diceDBCmd.Cmd, strings.Join(diceDBCmd.Args, " ")))); err != nil { - return err - } - } - case MultiShard, AllShard: - err = t.writeResponse(ctx, cmdMeta.composeResponse(storeOp...)) - - if err == nil && t.wl != nil { - if err := t.wl.LogCommand([]byte(fmt.Sprintf("%s %s", diceDBCmd.Cmd, strings.Join(diceDBCmd.Args, " ")))); err != nil { - return err - } - } - default: - slog.Error("Unknown command type", - slog.String("id", t.id), - slog.String("command", diceDBCmd.Cmd), - slog.Any("evalResp", storeOp)) - err = t.writeResponse(ctx, diceerrors.ErrInternalServer) - } - return err -} - -// writeResponse handles writing responses and logging errors -func (t *BaseIOThread) writeResponse(ctx context.Context, response interface{}) error { - err := 
t.ioHandler.Write(ctx, response) - if err != nil { - slog.Debug("Error sending response to client", - slog.String("id", t.id), - slog.Any("error", err)) - } - return err -} - -func (t *BaseIOThread) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error { - if diceDBCmd.Cmd != auth.Cmd && !t.Session.IsActive() { - return errors.New("NOAUTH Authentication required") - } - - return nil -} - func (t *BaseIOThread) Stop() error { slog.Info("Stopping io-thread", slog.String("id", t.id)) t.Session.Expire() return nil } - -func GenerateUniqueRequestID() uint32 { - return atomic.AddUint32(&requestCounter, 1) -} diff --git a/internal/iothread/manager.go b/internal/iothread/manager.go index 3cf95eaed..9bcfcb451 100644 --- a/internal/iothread/manager.go +++ b/internal/iothread/manager.go @@ -20,15 +20,12 @@ import ( "errors" "sync" "sync/atomic" - - "github.com/dicedb/dice/internal/shard" ) type Manager struct { connectedClients sync.Map numIOThreads atomic.Int32 maxClients int32 - shardManager *shard.ShardManager mu sync.Mutex } @@ -37,10 +34,9 @@ var ( ErrIOThreadNotFound = errors.New("io-thread not found") ) -func NewManager(maxClients int32, sm *shard.ShardManager) *Manager { +func NewManager(maxClients int32) *Manager { return &Manager{ - maxClients: maxClients, - shardManager: sm, + maxClients: maxClients, } } @@ -53,14 +49,6 @@ func (m *Manager) RegisterIOThread(ioThread IOThread) error { } m.connectedClients.Store(ioThread.ID(), ioThread) - responseChan := ioThread.(*BaseIOThread).responseChan - preprocessingChan := ioThread.(*BaseIOThread).preprocessingChan - - if responseChan != nil && preprocessingChan != nil { - m.shardManager.RegisterIOThread(ioThread.ID(), responseChan, preprocessingChan) // TODO: Change responseChan type to ShardResponse - } else if responseChan != nil && preprocessingChan == nil { - m.shardManager.RegisterIOThread(ioThread.ID(), responseChan, nil) - } m.numIOThreads.Add(1) return nil @@ -88,8 +76,6 @@ func (m *Manager) UnregisterIOThread(id string) 
error { return ErrIOThreadNotFound } - m.shardManager.UnregisterIOThread(id) m.numIOThreads.Add(-1) - return nil } diff --git a/internal/ops/store_op.go b/internal/ops/store_op.go index 376223a00..236f63fea 100644 --- a/internal/ops/store_op.go +++ b/internal/ops/store_op.go @@ -27,8 +27,8 @@ type StoreOp struct { RequestID uint32 // RequestID identifies the request that this StoreOp belongs to Cmd *cmd.DiceDBCmd // Cmd is the atomic Store command (e.g., GET, SET) ShardID uint8 // ShardID of the shard on which the Store command will be executed - IOThreadID string // IOThreadID is the ID of the io-thread that sent this Store operation - Client *comm.Client // Client that sent this Store operation. TODO: This can potentially replace the IOThreadID in the future + CmdHandlerID string // CmdHandlerID is the ID of the command handler that sent this Store operation + Client *comm.Client // Client that sent this Store operation. TODO: This can potentially replace the CmdHandlerID in the future HTTPOp bool // HTTPOp is true if this Store operation is an HTTP operation WebsocketOp bool // WebsocketOp is true if this Store operation is a Websocket operation PreProcessing bool // PreProcessing indicates whether a comamnd operation requires preprocessing before execution. 
This is mainly used is multi-step-multi-shard commands diff --git a/internal/server/httpws/httpServer.go b/internal/server/httpws/httpServer.go index ab24bce51..f4b82da6d 100644 --- a/internal/server/httpws/httpServer.go +++ b/internal/server/httpws/httpServer.go @@ -28,7 +28,7 @@ import ( "sync" "time" - "github.com/dicedb/dice/internal/iothread" + "github.com/dicedb/dice/internal/commandhandler" "github.com/dicedb/dice/internal/eval" "github.com/dicedb/dice/internal/server/abstractserver" @@ -44,8 +44,9 @@ import ( ) const ( - Abort = "ABORT" - stringNil = "(nil)" + Abort = "ABORT" + stringNil = "(nil)" + httpCmdHandlerID = "httpServer" ) var unimplementedCommands = map[string]bool{ @@ -113,7 +114,7 @@ func (s *HTTPServer) Run(ctx context.Context) error { httpCtx, cancelHTTP := context.WithCancel(ctx) defer cancelHTTP() - s.shardManager.RegisterIOThread("httpServer", s.ioChan, nil) + s.shardManager.RegisterCommandHandler(httpCmdHandlerID, s.ioChan, nil) wg.Add(1) go func() { @@ -157,7 +158,7 @@ func (s *HTTPServer) DiceHTTPHandler(writer http.ResponseWriter, request *http.R return } - if iothread.CommandsMeta[diceDBCmd.Cmd].CmdType == iothread.MultiShard { + if commandhandler.CommandsMeta[diceDBCmd.Cmd].CmdType == commandhandler.MultiShard { writeErrorResponse(writer, http.StatusBadRequest, "unsupported command", "Unsupported command received", slog.String("cmd", diceDBCmd.Cmd)) return @@ -179,10 +180,10 @@ func (s *HTTPServer) DiceHTTPHandler(writer http.ResponseWriter, request *http.R // send request to Shard Manager s.shardManager.GetShard(0).ReqChan <- &ops.StoreOp{ - Cmd: diceDBCmd, - IOThreadID: "httpServer", - ShardID: 0, - HTTPOp: true, + Cmd: diceDBCmd, + CmdHandlerID: httpCmdHandlerID, + ShardID: 0, + HTTPOp: true, } // Wait for response @@ -230,11 +231,11 @@ func (s *HTTPServer) DiceHTTPQwatchHandler(writer http.ResponseWriter, request * qwatchClient := comm.NewHTTPQwatchClient(s.qwatchResponseChan, clientIdentifierID) // Prepare the store operation 
storeOp := &ops.StoreOp{ - Cmd: diceDBCmd, - IOThreadID: "httpServer", - ShardID: 0, - Client: qwatchClient, - HTTPOp: true, + Cmd: diceDBCmd, + CmdHandlerID: httpCmdHandlerID, + ShardID: 0, + Client: qwatchClient, + HTTPOp: true, } slog.Info("Registered client for watching query", slog.Any("clientID", clientIdentifierID), @@ -354,9 +355,9 @@ func (s *HTTPServer) writeResponse(writer http.ResponseWriter, result *ops.Store // Check if the command is migrated, if it is we use EvalResponse values // else we use RESPParser to decode the response - _, ok := iothread.CommandsMeta[diceDBCmd.Cmd] + _, ok := commandhandler.CommandsMeta[diceDBCmd.Cmd] // TODO: Remove this conditional check and if (true) condition when all commands are migrated - if !ok || iothread.CommandsMeta[diceDBCmd.Cmd].CmdType == iothread.Custom { + if !ok || commandhandler.CommandsMeta[diceDBCmd.Cmd].CmdType == commandhandler.Custom { responseValue, err = DecodeEvalResponse(result.EvalResponse) if err != nil { slog.Error("Error decoding response", "error", err) diff --git a/internal/server/httpws/websocketServer.go b/internal/server/httpws/websocketServer.go index 89e8ed4a1..f1fb8b142 100644 --- a/internal/server/httpws/websocketServer.go +++ b/internal/server/httpws/websocketServer.go @@ -30,7 +30,7 @@ import ( "syscall" "time" - "github.com/dicedb/dice/internal/iothread" + "github.com/dicedb/dice/internal/commandhandler" "github.com/dicedb/dice/internal/server/abstractserver" "github.com/dicedb/dice/internal/wal" @@ -46,9 +46,12 @@ import ( "golang.org/x/exp/rand" ) -const Qwatch = "Q.WATCH" -const Qunwatch = "Q.UNWATCH" -const Subscribe = "SUBSCRIBE" +const ( + Qwatch = "Q.WATCH" + Qunwatch = "Q.UNWATCH" + Subscribe = "SUBSCRIBE" + wsCmdHandlerID = "wsServer" +) var unimplementedCommandsWebsocket = map[string]bool{ Qunwatch: true, @@ -96,7 +99,7 @@ func (s *WebsocketServer) Run(ctx context.Context) error { websocketCtx, cancelWebsocket := context.WithCancel(ctx) defer cancelWebsocket() - 
s.shardManager.RegisterIOThread("wsServer", s.ioChan, nil) + s.shardManager.RegisterCommandHandler(wsCmdHandlerID, s.ioChan, nil) wg.Add(1) go func() { @@ -170,7 +173,7 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques continue } - if iothread.CommandsMeta[diceDBCmd.Cmd].CmdType == iothread.MultiShard { + if commandhandler.CommandsMeta[diceDBCmd.Cmd].CmdType == commandhandler.MultiShard { if err := WriteResponseWithRetries(conn, []byte("error: unsupported command"), maxRetries); err != nil { slog.Debug(fmt.Sprintf("Error writing message: %v", err)) } @@ -192,10 +195,10 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques // create request sp := &ops.StoreOp{ - Cmd: diceDBCmd, - IOThreadID: "wsServer", - ShardID: 0, - WebsocketOp: true, + Cmd: diceDBCmd, + CmdHandlerID: wsCmdHandlerID, + ShardID: 0, + WebsocketOp: true, } // handle q.watch commands @@ -295,7 +298,7 @@ func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.D var responseValue interface{} // Check if the command is migrated, if it is we use EvalResponse values // else we use RESPParser to decode the response - _, ok := iothread.CommandsMeta[diceDBCmd.Cmd] + _, ok := commandhandler.CommandsMeta[diceDBCmd.Cmd] // TODO: Remove this conditional check and if (true) condition when all commands are migrated if !ok { responseValue, err = DecodeEvalResponse(response.EvalResponse) diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go index 9a23184dc..3c83b71d1 100644 --- a/internal/server/resp/server.go +++ b/internal/server/resp/server.go @@ -27,6 +27,8 @@ import ( "syscall" "time" + "github.com/dicedb/dice/internal/commandhandler" + "github.com/dicedb/dice/internal/ops" "github.com/dicedb/dice/internal/server/abstractserver" "github.com/dicedb/dice/internal/wal" @@ -37,13 +39,13 @@ import ( "github.com/dicedb/dice/internal/clientio/iohandler/netconn" respparser 
"github.com/dicedb/dice/internal/clientio/requestparser/resp" "github.com/dicedb/dice/internal/iothread" - "github.com/dicedb/dice/internal/ops" "github.com/dicedb/dice/internal/shard" ) var ( - ioThreadCounter uint64 - startTime = time.Now().UnixNano() / int64(time.Millisecond) + ioThreadCounter uint64 + cmdHandlerCounter uint64 + startTime = time.Now().UnixNano() / int64(time.Millisecond) ) var ( @@ -61,6 +63,7 @@ type Server struct { serverFD int connBacklogSize int ioThreadManager *iothread.Manager + cmdHandlerManager *commandhandler.Registry shardManager *shard.ShardManager watchManager *watchmanager.Manager cmdWatchSubscriptionChan chan watchmanager.WatchSubscription @@ -68,13 +71,15 @@ type Server struct { wl wal.AbstractWAL } -func NewServer(shardManager *shard.ShardManager, ioThreadManager *iothread.Manager, - cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent, globalErrChan chan error, wl wal.AbstractWAL) *Server { +func NewServer(shardManager *shard.ShardManager, ioThreadManager *iothread.Manager, cmdHandlerManager *commandhandler.Registry, + cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent, + globalErrChan chan error, wl wal.AbstractWAL) *Server { return &Server{ Host: config.DiceConfig.RespServer.Addr, Port: config.DiceConfig.RespServer.Port, connBacklogSize: DefaultConnBacklogSize, ioThreadManager: ioThreadManager, + cmdHandlerManager: cmdHandlerManager, shardManager: shardManager, watchManager: watchmanager.NewManager(cmdWatchSubscriptionChan, cmdWatchChan), cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, @@ -207,13 +212,22 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou return err } - parser := respparser.NewParser() + // create a new io-thread + ioThreadID := GenerateUniqueIOThreadID() + ioThreadReadChan := make(chan []byte) // for sending data to the command handler from the io-thread + ioThreadWriteChan := 
make(chan interface{}) // for sending data to the io-thread from the command handler + ioThreadErrChan := make(chan error, 1) // for receiving errors from the io-thread + thread := iothread.NewIOThread(ioThreadID, ioHandler, ioThreadReadChan, ioThreadWriteChan, ioThreadErrChan) + // For each io-thread, we create a dedicated command handler - 1:1 mapping + cmdHandlerID := GenerateUniqueCommandHandlerID() + parser := respparser.NewParser() responseChan := make(chan *ops.StoreResponse) // responseChan is used for handling common responses from shards preprocessingChan := make(chan *ops.StoreResponse) // preprocessingChan is specifically for handling responses from shards for commands that require preprocessing - ioThreadID := GenerateUniqueIOThreadID() - thread := iothread.NewIOThread(ioThreadID, responseChan, preprocessingChan, s.cmdWatchSubscriptionChan, ioHandler, parser, s.shardManager, s.globalErrorChan, s.wl) + handler := commandhandler.NewCommandHandler(cmdHandlerID, responseChan, preprocessingChan, + s.cmdWatchSubscriptionChan, parser, s.shardManager, s.globalErrorChan, + ioThreadReadChan, ioThreadWriteChan, ioThreadErrChan, s.wl) // Register the io-thread with the manager err = s.ioThreadManager.RegisterIOThread(thread) @@ -223,6 +237,15 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou wg.Add(1) go s.startIOThread(ctx, wg, thread) + + // Register the command handler with the manager + err = s.cmdHandlerManager.RegisterCommandHandler(handler) + if err != nil { + return err + } + + wg.Add(1) + go s.startCommandHandler(ctx, wg, handler) } } } @@ -243,12 +266,34 @@ func (s *Server) startIOThread(ctx context.Context, wg *sync.WaitGroup, thread * } } +func (s *Server) startCommandHandler(ctx context.Context, wg *sync.WaitGroup, cmdHandler *commandhandler.BaseCommandHandler) { + wg.Done() + defer func(wm *commandhandler.Registry, id string) { + err := wm.UnregisterCommandHandler(id) + if err != nil { + slog.Warn("Failed to 
unregister command handler", slog.String("id", id), slog.Any("error", err)) + } + }(s.cmdHandlerManager, cmdHandler.ID()) + ctx2, cancel := context.WithCancel(ctx) + defer cancel() + err := cmdHandler.Start(ctx2) + if err != nil { + slog.Debug("CommandHandler stopped", slog.String("id", cmdHandler.ID()), slog.Any("error", err)) + } +} + func GenerateUniqueIOThreadID() string { count := atomic.AddUint64(&ioThreadCounter, 1) timestamp := time.Now().UnixNano()/int64(time.Millisecond) - startTime return fmt.Sprintf("W-%d-%d", timestamp, count) } +func GenerateUniqueCommandHandlerID() string { + count := atomic.AddUint64(&cmdHandlerCounter, 1) + timestamp := time.Now().UnixNano()/int64(time.Millisecond) - startTime + return fmt.Sprintf("W-%d-%d", timestamp, count) +} + func (s *Server) Shutdown() { // Not implemented } diff --git a/internal/shard/shard_manager.go b/internal/shard/shard_manager.go index 7fdc83e66..4ff603817 100644 --- a/internal/shard/shard_manager.go +++ b/internal/shard/shard_manager.go @@ -120,15 +120,16 @@ func (manager *ShardManager) GetShard(id ShardID) *ShardThread { return nil } -// RegisterIOThread registers a io-thread with all Shards present in the ShardManager. -func (manager *ShardManager) RegisterIOThread(id string, request, processing chan *ops.StoreResponse) { +// RegisterCommandHandler registers a command handler with all Shards present in the ShardManager. +func (manager *ShardManager) RegisterCommandHandler(id string, request, processing chan *ops.StoreResponse) { for _, shard := range manager.shards { - shard.registerIOThread(id, request, processing) + shard.registerCommandHandler(id, request, processing) } } -func (manager *ShardManager) UnregisterIOThread(id string) { +// UnregisterCommandHandler unregisters a command handler from all Shards present in the ShardManager. 
+func (manager *ShardManager) UnregisterCommandHandler(id string) { for _, shard := range manager.shards { - shard.unregisterIOThread(id) + shard.unregisterCommandHandler(id) } } diff --git a/internal/shard/shard_thread.go b/internal/shard/shard_thread.go index 02537e4f6..b7c93d14f 100644 --- a/internal/shard/shard_thread.go +++ b/internal/shard/shard_thread.go @@ -37,23 +37,24 @@ type ShardError struct { Error error // Error is the error that occurred } -// IOChannels holds the communication channels for an io-thread. -// It contains both the common response channel and the preprocessing response channel. -type IOChannels struct { - CommonResponseChan chan *ops.StoreResponse // CommonResponseChan is used to send standard responses for io-thread operations. +// CmdHandlerChannels holds the communication channels for a Command Handler. +// It contains both the response channel and the preprocessing response channel. +type CmdHandlerChannels struct { + ResponseChan chan *ops.StoreResponse // ResponseChan is used to send standard responses for Command Handler operations. PreProcessingResponseChan chan *ops.StoreResponse // PreProcessingResponseChan is used to send responses related to preprocessing operations. } type ShardThread struct { - id ShardID // id is the unique identifier for the shard. - store *dstore.Store // store that the shard is responsible for. - ReqChan chan *ops.StoreOp // ReqChan is this shard's channel for receiving requests. - ioThreadMap map[string]IOChannels // ioThreadMap maps each io-thread id to its corresponding IOChannels, containing both the common and preprocessing response channels. - mu sync.RWMutex // mu is the ioThreadMap's mutex for thread safety. - globalErrorChan chan error // globalErrorChan is the channel for sending system-level errors. - shardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors. - lastCronExecTime time.Time // lastCronExecTime is the last time the shard executed cron tasks. 
- cronFrequency time.Duration // cronFrequency is the frequency at which the shard executes cron tasks. + id ShardID // id is the unique identifier for the shard. + store *dstore.Store // store that the shard is responsible for. + ReqChan chan *ops.StoreOp // ReqChan is this shard's channel for receiving requests. + // cmdHandlerMap maps each command handler id to its corresponding CmdHandlerChannels, containing both the response and preprocessing response channels. + cmdHandlerMap map[string]CmdHandlerChannels + mu sync.RWMutex // mu is the cmdHandlerMap's mutex for thread safety. + globalErrorChan chan error // globalErrorChan is the channel for sending system-level errors. + shardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors. + lastCronExecTime time.Time // lastCronExecTime is the last time the shard executed cron tasks. + cronFrequency time.Duration // cronFrequency is the frequency at which the shard executes cron tasks. } // NewShardThread creates a new ShardThread instance with the given shard id and error channel. 
@@ -63,7 +64,7 @@ func NewShardThread(id ShardID, gec chan error, sec chan *ShardError, id: id, store: dstore.NewStore(cmdWatchChan, evictionStrategy), ReqChan: make(chan *ops.StoreOp, 1000), - ioThreadMap: make(map[string]IOChannels), + cmdHandlerMap: make(map[string]CmdHandlerChannels), globalErrorChan: gec, shardErrorChan: sec, lastCronExecTime: utils.GetCurrentTime(), @@ -95,30 +96,30 @@ func (shard *ShardThread) runCronTasks() { shard.lastCronExecTime = utils.GetCurrentTime() } -func (shard *ShardThread) registerIOThread(id string, responseChan, preprocessingChan chan *ops.StoreResponse) { +func (shard *ShardThread) registerCommandHandler(id string, responseChan, preprocessingChan chan *ops.StoreResponse) { shard.mu.Lock() - shard.ioThreadMap[id] = IOChannels{ - CommonResponseChan: responseChan, + shard.cmdHandlerMap[id] = CmdHandlerChannels{ + ResponseChan: responseChan, PreProcessingResponseChan: preprocessingChan, } shard.mu.Unlock() } -func (shard *ShardThread) unregisterIOThread(id string) { +func (shard *ShardThread) unregisterCommandHandler(id string) { shard.mu.Lock() - delete(shard.ioThreadMap, id) + delete(shard.cmdHandlerMap, id) shard.mu.Unlock() } // processRequest processes a Store operation for the shard. 
func (shard *ShardThread) processRequest(op *ops.StoreOp) { shard.mu.RLock() - ioChannels, ok := shard.ioThreadMap[op.IOThreadID] + channels, ok := shard.cmdHandlerMap[op.CmdHandlerID] shard.mu.RUnlock() - ioThreadChan := ioChannels.CommonResponseChan - preProcessChan := ioChannels.PreProcessingResponseChan + cmdHandlerChan := channels.ResponseChan + preProcessChan := channels.PreProcessingResponseChan sp := &ops.StoreResponse{ RequestID: op.RequestID, @@ -140,11 +141,11 @@ func (shard *ShardThread) processRequest(op *ops.StoreOp) { } else { shard.shardErrorChan <- &ShardError{ ShardID: shard.id, - Error: fmt.Errorf(diceerrors.IOThreadNotFoundErr, op.IOThreadID), + Error: fmt.Errorf(diceerrors.CmdHandlerNotFoundErr, op.CmdHandlerID), } } - ioThreadChan <- sp + cmdHandlerChan <- sp } // cleanup handles cleanup logic when the shard stops. diff --git a/main.go b/main.go index e24f76a99..01ea98558 100644 --- a/main.go +++ b/main.go @@ -34,6 +34,7 @@ import ( "github.com/dicedb/dice/internal/server/httpws" "github.com/dicedb/dice/internal/cli" + "github.com/dicedb/dice/internal/commandhandler" "github.com/dicedb/dice/internal/logger" "github.com/dicedb/dice/internal/server/abstractserver" "github.com/dicedb/dice/internal/wal" @@ -145,8 +146,10 @@ func main() { } defer stopProfiling() } - ioThreadManager := iothread.NewManager(config.DiceConfig.Performance.MaxClients, shardManager) - respServer := resp.NewServer(shardManager, ioThreadManager, cmdWatchSubscriptionChan, cmdWatchChan, serverErrCh, wl) + ioThreadManager := iothread.NewManager(config.DiceConfig.Performance.MaxClients) + cmdHandlerManager := commandhandler.NewRegistry(config.DiceConfig.Performance.MaxClients, shardManager) + + respServer := resp.NewServer(shardManager, ioThreadManager, cmdHandlerManager, cmdWatchSubscriptionChan, cmdWatchChan, serverErrCh, wl) serverWg.Add(1) go runServer(ctx, &serverWg, respServer, serverErrCh)