diff --git a/integration_tests/commands/resp/setup.go b/integration_tests/commands/resp/setup.go
index c83aa0143..dffb6d216 100644
--- a/integration_tests/commands/resp/setup.go
+++ b/integration_tests/commands/resp/setup.go
@@ -29,6 +29,7 @@ import (
"testing"
"time"
+ "github.com/dicedb/dice/internal/commandhandler"
"github.com/dicedb/dice/internal/iothread"
"github.com/dicedb/dice/internal/server/resp"
"github.com/dicedb/dice/internal/wal"
@@ -211,10 +212,12 @@ func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) {
cmdWatchSubscriptionChan := make(chan watchmanager.WatchSubscription)
gec := make(chan error)
shardManager := shard.NewShardManager(1, cmdWatchChan, gec)
- ioThreadManager := iothread.NewManager(20000, shardManager)
+ ioThreadManager := iothread.NewManager(20000)
+ cmdHandlerManager := commandhandler.NewRegistry(20000, shardManager)
+
// Initialize the RESP Server
wl, _ := wal.NewNullWAL()
- testServer := resp.NewServer(shardManager, ioThreadManager, cmdWatchSubscriptionChan, cmdWatchChan, gec, wl)
+ testServer := resp.NewServer(shardManager, ioThreadManager, cmdHandlerManager, cmdWatchSubscriptionChan, cmdWatchChan, gec, wl)
ctx, cancel := context.WithCancel(context.Background())
fmt.Println("Starting the test server on port", config.DiceConfig.RespServer.Port)
diff --git a/internal/iothread/cmd_compose.go b/internal/commandhandler/cmd_compose.go
similarity index 98%
rename from internal/iothread/cmd_compose.go
rename to internal/commandhandler/cmd_compose.go
index 081871f47..6ea026550 100644
--- a/internal/iothread/cmd_compose.go
+++ b/internal/commandhandler/cmd_compose.go
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
-package iothread
+package commandhandler
import (
"math"
@@ -24,7 +24,7 @@ import (
"github.com/dicedb/dice/internal/ops"
)
-// This file contains functions used by the IOThread to handle and process responses
+// This file contains functions used by the CommandHandler to handle and process responses
// from multiple shards during distributed operations. For commands that are executed
// across several shards, such as MultiShard commands, dedicated functions are responsible
// for aggregating and managing the results.
diff --git a/internal/iothread/cmd_custom.go b/internal/commandhandler/cmd_custom.go
similarity index 95%
rename from internal/iothread/cmd_custom.go
rename to internal/commandhandler/cmd_custom.go
index 5dd568a92..2fcc88d5a 100644
--- a/internal/iothread/cmd_custom.go
+++ b/internal/commandhandler/cmd_custom.go
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
-package iothread
+package commandhandler
import (
"fmt"
@@ -28,7 +28,7 @@ import (
// RespAuth returns with an encoded "OK" if the user is authenticated
// If the user is not authenticated, it returns with an encoded error message
-func (t *BaseIOThread) RespAuth(args []string) interface{} {
+func (h *BaseCommandHandler) RespAuth(args []string) interface{} {
// Check for incorrect number of arguments (arity error).
if len(args) < 1 || len(args) > 2 {
return diceerrors.ErrWrongArgumentCount("AUTH")
@@ -47,7 +47,7 @@ func (t *BaseIOThread) RespAuth(args []string) interface{} {
username, password = args[0], args[1]
}
- if err := t.Session.Validate(username, password); err != nil {
+ if err := h.Session.Validate(username, password); err != nil {
return err
}
diff --git a/internal/iothread/cmd_decompose.go b/internal/commandhandler/cmd_decompose.go
similarity index 80%
rename from internal/iothread/cmd_decompose.go
rename to internal/commandhandler/cmd_decompose.go
index 0ba439aa3..b9e881153 100644
--- a/internal/iothread/cmd_decompose.go
+++ b/internal/commandhandler/cmd_decompose.go
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
-package iothread
+package commandhandler
import (
"context"
@@ -28,7 +28,7 @@ import (
"github.com/dicedb/dice/internal/store"
)
-// This file is utilized by the IOThread to decompose commands that need to be executed
+// This file is utilized by the CommandHandler to decompose commands that need to be executed
// across multiple shards. For commands that operate on multiple keys or necessitate
// distribution across shards (e.g., MultiShard commands), a Breakup function is invoked
// to transform the original command into multiple smaller commands, each directed at
@@ -41,13 +41,13 @@ import (
// decomposeRename breaks down the RENAME command into separate DELETE and SET commands.
// It first waits for the result of a GET command from shards. If successful, it removes
// the old key using a DEL command and sets the new key with the retrieved value using a SET command.
-func decomposeRename(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeRename(ctx context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
// Waiting for GET command response
var val string
select {
case <-ctx.Done():
- slog.Error("IOThread timed out waiting for response from shards", slog.String("id", thread.id), slog.Any("error", ctx.Err()))
- case preProcessedResp, ok := <-thread.preprocessingChan:
+ slog.Error("CommandHandler timed out waiting for response from shards", slog.String("id", h.id), slog.Any("error", ctx.Err()))
+ case preProcessedResp, ok := <-h.preprocessingChan:
if ok {
evalResp := preProcessedResp.EvalResponse
if evalResp.Error != nil {
@@ -85,13 +85,13 @@ func decomposeRename(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCm
// decomposeCopy breaks down the COPY command into a SET command that copies a value from
// one key to another. It first retrieves the value of the original key from shards, then
// sets the value to the destination key using a SET command.
-func decomposeCopy(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeCopy(ctx context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
// Waiting for GET command response
var resp *ops.StoreResponse
select {
case <-ctx.Done():
- slog.Error("IOThread timed out waiting for response from shards", slog.String("id", thread.id), slog.Any("error", ctx.Err()))
- case preProcessedResp, ok := <-thread.preprocessingChan:
+ slog.Error("CommandHandler timed out waiting for response from shards", slog.String("id", h.id), slog.Any("error", ctx.Err()))
+ case preProcessedResp, ok := <-h.preprocessingChan:
if ok {
resp = preProcessedResp
}
@@ -124,7 +124,7 @@ func decomposeCopy(ctx context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd)
// decomposeMSet decomposes the MSET (Multi-set) command into individual SET commands.
// It expects an even number of arguments (key-value pairs). For each pair, it creates
// a separate SET command to store the value at the given key.
-func decomposeMSet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeMSet(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
if len(cd.Args)%2 != 0 {
return nil, diceerrors.ErrWrongArgumentCount("MSET")
}
@@ -148,7 +148,7 @@ func decomposeMSet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cm
// decomposeMGet decomposes the MGET (Multi-get) command into individual GET commands.
// It expects a list of keys, and for each key, it creates a separate GET command to
// retrieve the value associated with that key.
-func decomposeMGet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeMGet(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
if len(cd.Args) < 1 {
return nil, diceerrors.ErrWrongArgumentCount("MGET")
}
@@ -164,7 +164,7 @@ func decomposeMGet(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cm
return decomposedCmds, nil
}
-func decomposeSInter(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeSInter(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
if len(cd.Args) < 1 {
return nil, diceerrors.ErrWrongArgumentCount("SINTER")
}
@@ -180,7 +180,7 @@ func decomposeSInter(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*
return decomposedCmds, nil
}
-func decomposeSDiff(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeSDiff(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
if len(cd.Args) < 1 {
return nil, diceerrors.ErrWrongArgumentCount("SDIFF")
}
@@ -196,7 +196,7 @@ func decomposeSDiff(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*c
return decomposedCmds, nil
}
-func decomposeJSONMget(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeJSONMget(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
if len(cd.Args) < 2 {
return nil, diceerrors.ErrWrongArgumentCount("JSON.MGET")
}
@@ -215,7 +215,7 @@ func decomposeJSONMget(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([
return decomposedCmds, nil
}
-func decomposeTouch(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeTouch(_ context.Context, _ *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
if len(cd.Args) == 0 {
return nil, diceerrors.ErrWrongArgumentCount("TOUCH")
}
@@ -232,13 +232,13 @@ func decomposeTouch(_ context.Context, _ *BaseIOThread, cd *cmd.DiceDBCmd) ([]*c
return decomposedCmds, nil
}
-func decomposeDBSize(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeDBSize(_ context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
if len(cd.Args) > 0 {
return nil, diceerrors.ErrWrongArgumentCount("DBSIZE")
}
decomposedCmds := make([]*cmd.DiceDBCmd, 0, len(cd.Args))
- for i := uint8(0); i < uint8(thread.shardManager.GetShardCount()); i++ {
+ for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ {
decomposedCmds = append(decomposedCmds,
&cmd.DiceDBCmd{
Cmd: store.SingleShardSize,
@@ -249,13 +249,13 @@ func decomposeDBSize(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd)
return decomposedCmds, nil
}
-func decomposeKeys(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeKeys(_ context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
if len(cd.Args) != 1 {
return nil, diceerrors.ErrWrongArgumentCount("KEYS")
}
decomposedCmds := make([]*cmd.DiceDBCmd, 0, len(cd.Args))
- for i := uint8(0); i < uint8(thread.shardManager.GetShardCount()); i++ {
+ for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ {
decomposedCmds = append(decomposedCmds,
&cmd.DiceDBCmd{
Cmd: store.SingleShardKeys,
@@ -266,13 +266,13 @@ func decomposeKeys(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) (
return decomposedCmds, nil
}
-func decomposeFlushDB(_ context.Context, thread *BaseIOThread, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
+func decomposeFlushDB(_ context.Context, h *BaseCommandHandler, cd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error) {
if len(cd.Args) > 1 {
return nil, diceerrors.ErrWrongArgumentCount("FLUSHDB")
}
decomposedCmds := make([]*cmd.DiceDBCmd, 0, len(cd.Args))
- for i := uint8(0); i < uint8(thread.shardManager.GetShardCount()); i++ {
+ for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ {
decomposedCmds = append(decomposedCmds,
&cmd.DiceDBCmd{
Cmd: store.FlushDB,
diff --git a/internal/iothread/cmd_meta.go b/internal/commandhandler/cmd_meta.go
similarity index 96%
rename from internal/iothread/cmd_meta.go
rename to internal/commandhandler/cmd_meta.go
index 837a5be4b..2dd732b1b 100644
--- a/internal/iothread/cmd_meta.go
+++ b/internal/commandhandler/cmd_meta.go
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
-package iothread
+package commandhandler
import (
"context"
@@ -214,12 +214,12 @@ const (
type CmdMeta struct {
CmdType
- Cmd string
- IOThreadHandler func([]string) []byte
+ Cmd string
+ CmdHandlerFunction func([]string) []byte
// decomposeCommand is a function that takes a DiceDB command and breaks it down into smaller,
// manageable DiceDB commands for each shard processing. It returns a slice of DiceDB commands.
- decomposeCommand func(ctx context.Context, thread *BaseIOThread, DiceDBCmd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error)
+ decomposeCommand func(ctx context.Context, h *BaseCommandHandler, DiceDBCmd *cmd.DiceDBCmd) ([]*cmd.DiceDBCmd, error)
// composeResponse is a function that combines multiple responses from the execution of commands
// into a single response object. It accepts a variadic parameter of EvalResponse objects
@@ -234,10 +234,10 @@ type CmdMeta struct {
// preProcessResponse is a function that handles the preprocessing of a DiceDB command by
// preparing the necessary operations (e.g., fetching values from shards) before the command
- // is executed. It takes the io-thread and the original DiceDB command as parameters and
+ // is executed. It takes the CommandHandler and the original DiceDB command as parameters and
// ensures that any required information is retrieved and processed in advance. Use this when set
// preProcessingReq = true.
- preProcessResponse func(thread *BaseIOThread, DiceDBCmd *cmd.DiceDBCmd) error
+ preProcessResponse func(h *BaseCommandHandler, DiceDBCmd *cmd.DiceDBCmd) error
}
var CommandsMeta = map[string]CmdMeta{
@@ -695,8 +695,8 @@ func init() {
func validateCmdMeta(c string, meta CmdMeta) error {
switch meta.CmdType {
case Global:
- if meta.IOThreadHandler == nil {
- return fmt.Errorf("global command %s must have IOThreadHandler function", c)
+ if meta.CmdHandlerFunction == nil {
+ return fmt.Errorf("global command %s must have CmdHandlerFunction function", c)
}
case MultiShard, AllShard:
if meta.decomposeCommand == nil || meta.composeResponse == nil {
diff --git a/internal/iothread/cmd_preprocess.go b/internal/commandhandler/cmd_preprocess.go
similarity index 87%
rename from internal/iothread/cmd_preprocess.go
rename to internal/commandhandler/cmd_preprocess.go
index b5116622f..c985997b7 100644
--- a/internal/iothread/cmd_preprocess.go
+++ b/internal/commandhandler/cmd_preprocess.go
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
-package iothread
+package commandhandler
import (
"github.com/dicedb/dice/internal/cmd"
@@ -25,13 +25,13 @@ import (
// preProcessRename prepares the RENAME command for preprocessing by sending a GET command
// to retrieve the value of the original key. The retrieved value is used later in the
// decomposeRename function to delete the old key and set the new key.
-func preProcessRename(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error {
+func preProcessRename(h *BaseCommandHandler, diceDBCmd *cmd.DiceDBCmd) error {
if len(diceDBCmd.Args) < 2 {
return diceerrors.ErrWrongArgumentCount("RENAME")
}
key := diceDBCmd.Args[0]
- sid, rc := thread.shardManager.GetShardInfo(key)
+ sid, rc := h.shardManager.GetShardInfo(key)
preCmd := cmd.DiceDBCmd{
Cmd: "RENAME",
@@ -42,7 +42,7 @@ func preProcessRename(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error {
SeqID: 0,
RequestID: GenerateUniqueRequestID(),
Cmd: &preCmd,
- IOThreadID: thread.id,
+ CmdHandlerID: h.id,
ShardID: sid,
Client: nil,
PreProcessing: true,
@@ -54,12 +54,12 @@ func preProcessRename(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error {
// preProcessCopy prepares the COPY command for preprocessing by sending a GET command
// to retrieve the value of the original key. The retrieved value is used later in the
// decomposeCopy function to copy the value to the destination key.
-func customProcessCopy(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error {
+func customProcessCopy(h *BaseCommandHandler, diceDBCmd *cmd.DiceDBCmd) error {
if len(diceDBCmd.Args) < 2 {
return diceerrors.ErrWrongArgumentCount("COPY")
}
- sid, rc := thread.shardManager.GetShardInfo(diceDBCmd.Args[0])
+ sid, rc := h.shardManager.GetShardInfo(diceDBCmd.Args[0])
preCmd := cmd.DiceDBCmd{
Cmd: "COPY",
@@ -71,7 +71,7 @@ func customProcessCopy(thread *BaseIOThread, diceDBCmd *cmd.DiceDBCmd) error {
SeqID: 0,
RequestID: GenerateUniqueRequestID(),
Cmd: &preCmd,
- IOThreadID: thread.id,
+ CmdHandlerID: h.id,
ShardID: sid,
Client: nil,
PreProcessing: true,
diff --git a/internal/commandhandler/commandhandler.go b/internal/commandhandler/commandhandler.go
new file mode 100644
index 000000000..d9c719498
--- /dev/null
+++ b/internal/commandhandler/commandhandler.go
@@ -0,0 +1,519 @@
+package commandhandler
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net"
+ "strconv"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "github.com/dicedb/dice/config"
+ "github.com/dicedb/dice/internal/auth"
+ "github.com/dicedb/dice/internal/clientio"
+ "github.com/dicedb/dice/internal/clientio/requestparser"
+ "github.com/dicedb/dice/internal/cmd"
+ diceerrors "github.com/dicedb/dice/internal/errors"
+ "github.com/dicedb/dice/internal/ops"
+ "github.com/dicedb/dice/internal/querymanager"
+ "github.com/dicedb/dice/internal/shard"
+ "github.com/dicedb/dice/internal/wal"
+ "github.com/dicedb/dice/internal/watchmanager"
+ "github.com/google/uuid"
+)
+
+const defaultRequestTimeout = 6 * time.Second
+
+var requestCounter uint32
+
+type CommandHandler interface {
+ ID() string
+ Start(ctx context.Context) error
+ Stop() error
+}
+
+type BaseCommandHandler struct {
+ CommandHandler
+ id string
+ parser requestparser.Parser
+ shardManager *shard.ShardManager
+ adhocReqChan chan *cmd.DiceDBCmd
+ Session *auth.Session
+ globalErrorChan chan error
+ ioThreadReadChan chan []byte // Channel to receive data from io-thread
+ ioThreadWriteChan chan interface{} // Channel to send data to io-thread
+ ioThreadErrChan chan error // Channel to receive errors from io-thread
+ responseChan chan *ops.StoreResponse // Channel to communicate with shard
+ preprocessingChan chan *ops.StoreResponse // Channel to communicate with shard
+ cmdWatchSubscriptionChan chan watchmanager.WatchSubscription
+ wl wal.AbstractWAL
+}
+
+func NewCommandHandler(id string, responseChan, preprocessingChan chan *ops.StoreResponse,
+ cmdWatchSubscriptionChan chan watchmanager.WatchSubscription,
+ parser requestparser.Parser, shardManager *shard.ShardManager, gec chan error,
+ ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}, ioThreadErrChan chan error,
+ wl wal.AbstractWAL) *BaseCommandHandler {
+ return &BaseCommandHandler{
+ id: id,
+ parser: parser,
+ shardManager: shardManager,
+ adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize),
+ Session: auth.NewSession(),
+ globalErrorChan: gec,
+ ioThreadReadChan: ioThreadReadChan,
+ ioThreadWriteChan: ioThreadWriteChan,
+ ioThreadErrChan: ioThreadErrChan,
+ responseChan: responseChan,
+ preprocessingChan: preprocessingChan,
+ cmdWatchSubscriptionChan: cmdWatchSubscriptionChan,
+ wl: wl,
+ }
+}
+
+func (h *BaseCommandHandler) ID() string {
+ return h.id
+}
+
+func (h *BaseCommandHandler) Start(ctx context.Context) error {
+ errChan := make(chan error, 1) // for adhoc request processing errors
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case err := <-h.ioThreadErrChan:
+ return err
+ case cmdReq := <-h.adhocReqChan:
+ resp, err := h.handleCmdRequestWithTimeout(ctx, errChan, []*cmd.DiceDBCmd{cmdReq}, true, defaultRequestTimeout)
+ h.sendResponseToIOThread(resp, err)
+ case err := <-errChan:
+ return h.handleError(err)
+ case data := <-h.ioThreadReadChan:
+ resp, err := h.processCommand(ctx, &data, h.globalErrorChan)
+ h.sendResponseToIOThread(resp, err)
+ }
+ }
+}
+
+// processCommand processes commands recevied from io thread
+func (h *BaseCommandHandler) processCommand(ctx context.Context, data *[]byte, gec chan error) (interface{}, error) {
+ commands, err := h.parser.Parse(*data)
+
+ if err != nil {
+ slog.Debug("error parsing commands from io thread", slog.String("id", h.id), slog.Any("error", err))
+ return nil, err
+ }
+
+ if len(commands) == 0 {
+ slog.Debug("invalid request from io thread with zero length", slog.String("id", h.id))
+ return nil, fmt.Errorf("ERR: Invalid request")
+ }
+
+ // DiceDB supports clients to send only one request at a time
+ // We also need to ensure that the client is blocked until the response is received
+ if len(commands) > 1 {
+ return nil, fmt.Errorf("ERR: Multiple commands not supported")
+ }
+
+ err = h.isAuthenticated(commands[0])
+ if err != nil {
+ slog.Debug("command handler authentication failed", slog.String("id", h.id), slog.Any("error", err))
+ return nil, err
+ }
+
+ return h.handleCmdRequestWithTimeout(ctx, gec, commands, false, defaultRequestTimeout)
+}
+
+func (h *BaseCommandHandler) handleCmdRequestWithTimeout(ctx context.Context, gec chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) (interface{}, error) {
+ execCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ return h.executeCommandHandler(execCtx, gec, commands, isWatchNotification)
+}
+
+func (h *BaseCommandHandler) executeCommandHandler(execCtx context.Context, gec chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) (interface{}, error) {
+ // Retrieve metadata for the command to determine if multisharding is supported.
+ meta, ok := CommandsMeta[commands[0].Cmd]
+ if ok && meta.preProcessing {
+ if err := meta.preProcessResponse(h, commands[0]); err != nil {
+ slog.Debug("error pre processing response", slog.String("id", h.id), slog.Any("error", err))
+ return nil, err
+ }
+ }
+
+ resp, err := h.executeCommand(execCtx, commands[0], isWatchNotification)
+
+ // log error and send to global error channel if it's a connection error
+ if err != nil {
+ slog.Error("Error executing command", slog.String("id", h.id), slog.Any("error", err))
+ if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) {
+ slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err))
+ gec <- err
+ }
+ }
+
+ return resp, err
+}
+
+func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) (interface{}, error) {
+ // Break down the single command into multiple commands if multisharding is supported.
+ // The length of cmdList helps determine how many shards to wait for responses.
+ cmdList := make([]*cmd.DiceDBCmd, 0)
+ var watchLabel string
+
+ // Retrieve metadata for the command to determine if multisharding is supported.
+ meta, ok := CommandsMeta[diceDBCmd.Cmd]
+ if !ok {
+ // If no metadata exists, treat it as a single command and not migrated
+ cmdList = append(cmdList, diceDBCmd)
+ } else {
+ // Depending on the command type, decide how to handle it.
+ switch meta.CmdType {
+ case Global:
+ // process global command immediately without involving any shards.
+ return meta.CmdHandlerFunction(diceDBCmd.Args), nil
+
+ case SingleShard:
+ // For single-shard or custom commands, process them without breaking up.
+ cmdList = append(cmdList, diceDBCmd)
+
+ case MultiShard, AllShard:
+ var err error
+ // If the command supports multisharding, break it down into multiple commands.
+ cmdList, err = meta.decomposeCommand(ctx, h, diceDBCmd)
+ if err != nil {
+ slog.Debug("error decomposing command", slog.String("id", h.id), slog.Any("error", err))
+ // Check if it's a CustomError
+ var customErr *diceerrors.PreProcessError
+ if errors.As(err, &customErr) {
+ return nil, err
+ }
+ return nil, err
+ }
+
+ case Custom:
+ return h.handleCustomCommands(diceDBCmd)
+
+ case Watch:
+ // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass
+ // it along as is.
+ // Modify the command name to remove the .WATCH suffix, this will allow us to generate a consistent
+ // fingerprint (which uses the command name without the suffix)
+ diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-6]
+
+ // check if the last argument is a watch label
+ label := diceDBCmd.Args[len(diceDBCmd.Args)-1]
+ if _, err := uuid.Parse(label); err == nil {
+ watchLabel = label
+
+ // remove the watch label from the args
+ diceDBCmd.Args = diceDBCmd.Args[:len(diceDBCmd.Args)-1]
+ }
+
+ watchCmd := &cmd.DiceDBCmd{
+ Cmd: diceDBCmd.Cmd,
+ Args: diceDBCmd.Args,
+ }
+ cmdList = append(cmdList, watchCmd)
+ isWatchNotification = true
+
+ case Unwatch:
+ // Generate the Cmd being unwatched. All we need to do is remove the .UNWATCH suffix from the command and pass
+ // it along as is.
+ // Modify the command name to remove the .UNWATCH suffix, this will allow us to generate a consistent
+ // fingerprint (which uses the command name without the suffix)
+ diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-8]
+ watchCmd := &cmd.DiceDBCmd{
+ Cmd: diceDBCmd.Cmd,
+ Args: diceDBCmd.Args,
+ }
+ cmdList = append(cmdList, watchCmd)
+ isWatchNotification = false
+ }
+ }
+
+ // Unsubscribe Unwatch command type
+ if meta.CmdType == Unwatch {
+ return h.handleCommandUnwatch(cmdList)
+ }
+
+ // Scatter the broken-down commands to the appropriate shards.
+ if err := h.scatter(ctx, cmdList, meta.CmdType); err != nil {
+ return nil, err
+ }
+
+ // Gather the responses from the shards and write them to the buffer.
+ resp, err := h.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel)
+ if err != nil {
+ return nil, err
+ }
+
+ // Proceed to subscribe after successful execution
+ if meta.CmdType == Watch {
+ h.handleCommandWatch(cmdList)
+ }
+
+ return resp, nil
+}
+
+func (h *BaseCommandHandler) handleCustomCommands(diceDBCmd *cmd.DiceDBCmd) (interface{}, error) {
+ // if command is of type Custom, write a custom logic around it
+ switch diceDBCmd.Cmd {
+ case CmdAuth:
+ return h.RespAuth(diceDBCmd.Args), nil
+ case CmdEcho:
+ return RespEcho(diceDBCmd.Args), nil
+ case CmdAbort:
+ slog.Info("Received ABORT command, initiating server shutdown", slog.String("id", h.id))
+ h.globalErrorChan <- diceerrors.ErrAborted
+ return clientio.OK, nil
+ case CmdPing:
+ return RespPING(diceDBCmd.Args), nil
+ case CmdHello:
+ return RespHello(diceDBCmd.Args), nil
+ case CmdSleep:
+ return RespSleep(diceDBCmd.Args), nil
+ default:
+ return nil, diceerrors.ErrUnknownCmd(diceDBCmd.Cmd)
+ }
+}
+
+// handleCommandWatch sends a watch subscription request to the watch manager.
+func (h *BaseCommandHandler) handleCommandWatch(cmdList []*cmd.DiceDBCmd) {
+ h.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{
+ Subscribe: true,
+ WatchCmd: cmdList[len(cmdList)-1],
+ AdhocReqChan: h.adhocReqChan,
+ }
+}
+
+// handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client.
+// The response is sent before the unwatch request is processed by the watch manager.
+func (h *BaseCommandHandler) handleCommandUnwatch(cmdList []*cmd.DiceDBCmd) (interface{}, error) {
+ // extract the fingerprint
+ command := cmdList[len(cmdList)-1]
+ fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32)
+ if parseErr != nil {
+ return nil, diceerrors.ErrInvalidFingerprint
+ }
+
+ // send the unsubscribe request
+ h.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{
+ Subscribe: false,
+ AdhocReqChan: h.adhocReqChan,
+ Fingerprint: uint32(fp),
+ }
+
+ return clientio.OK, nil
+}
+
+// scatter distributes the DiceDB commands to the respective shards based on the key.
+// For each command, it calculates the shard ID and sends the command to the shard's request channel for processing.
+func (h *BaseCommandHandler) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd, cmdType CmdType) error {
+ // Otherwise check for the shard based on the key using hash
+ // and send it to the particular shard
+ // Check if the context has been canceled or expired.
+ select {
+ case <-ctx.Done():
+ // If the context is canceled, return the error associated with it.
+ return ctx.Err()
+ default:
+ // Proceed with the default case when the context is not canceled.
+
+ if cmdType == AllShard {
+ // If the command type is for all shards, iterate over all available shards.
+ for i := uint8(0); i < uint8(h.shardManager.GetShardCount()); i++ {
+ // Get the shard ID (i) and its associated request channel.
+ shardID, responseChan := i, h.shardManager.GetShard(i).ReqChan
+
+ // Send a StoreOp operation to the shard's request channel.
+ responseChan <- &ops.StoreOp{
+ SeqID: i, // Sequence ID for this operation.
+ RequestID: GenerateUniqueRequestID(), // Unique identifier for the request.
+ Cmd: cmds[0], // Command to be executed, using the first command in cmds.
+ CmdHandlerID: h.id, // ID of the current command handler.
+ ShardID: shardID, // ID of the shard handling this operation.
+ Client: nil, // Client information (if applicable).
+ }
+ }
+ } else {
+ // If the command type is specific to certain commands, process them individually.
+ for i := uint8(0); i < uint8(len(cmds)); i++ {
+ // Determine the appropriate shard for the current command using a routing key.
+ shardID, responseChan := h.shardManager.GetShardInfo(getRoutingKeyFromCommand(cmds[i]))
+
+ // Send a StoreOp operation to the shard's request channel.
+ responseChan <- &ops.StoreOp{
+ SeqID: i, // Sequence ID for this operation.
+ RequestID: GenerateUniqueRequestID(), // Unique identifier for the request.
+ Cmd: cmds[i], // Command to be executed, using the current command in cmds.
+ CmdHandlerID: h.id, // ID of the current command handler.
+ ShardID: shardID, // ID of the shard handling this operation.
+ Client: nil, // Client information (if applicable).
+ }
+ }
+ }
+ }
+
+ return nil
+}
+
+// getRoutingKeyFromCommand determines the key used for shard routing
+func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string {
+ if len(diceDBCmd.Args) > 0 {
+ return diceDBCmd.Args[0]
+ }
+ return diceDBCmd.Cmd
+}
+
+// gather collects the responses from multiple shards and writes the results into the provided buffer.
+// It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard).
+func (h *BaseCommandHandler) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) (interface{}, error) {
+ // Collect responses from all shards
+ storeOp, err := h.gatherResponses(ctx, numCmds)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(storeOp) == 0 {
+ slog.Error("No response from shards",
+ slog.String("id", h.id),
+ slog.String("command", diceDBCmd.Cmd))
+ return nil, fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd)
+ }
+
+ if isWatchNotification {
+ return h.handleWatchNotification(diceDBCmd, storeOp[0], watchLabel)
+ }
+
+ // Process command based on its type
+ cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd]
+ if !ok {
+ return h.handleUnsupportedCommand(storeOp[0])
+ }
+
+ return h.handleCommand(cmdMeta, diceDBCmd, storeOp)
+}
+
+// gatherResponses collects responses from all shards
+func (h *BaseCommandHandler) gatherResponses(ctx context.Context, numCmds int) ([]ops.StoreResponse, error) {
+ storeOp := make([]ops.StoreResponse, 0, numCmds)
+
+ for numCmds > 0 {
+ select {
+ case <-ctx.Done():
+ slog.Error("Timed out waiting for response from shards",
+ slog.String("id", h.id),
+ slog.Any("error", ctx.Err()))
+ return nil, ctx.Err()
+
+ case resp, ok := <-h.responseChan:
+ if ok {
+ storeOp = append(storeOp, *resp)
+ }
+ numCmds--
+
+ case sError, ok := <-h.shardManager.ShardErrorChan:
+ if ok {
+ slog.Error("Error from shard",
+ slog.String("id", h.id),
+ slog.Any("error", sError))
+ return nil, sError.Error
+ }
+ }
+ }
+
+ return storeOp, nil
+}
+
+// handleWatchNotification processes watch notification responses
+func (h *BaseCommandHandler) handleWatchNotification(diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) (interface{}, error) {
+ fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint())
+
+ // if watch label is not empty, then this is the first response for the watch command
+ // hence, we will send the watch label as part of the response
+ firstRespElem := diceDBCmd.Cmd
+ if watchLabel != "" {
+ firstRespElem = watchLabel
+ }
+
+ if resp.EvalResponse.Error != nil {
+ // This is a special case where error is returned as part of the watch response
+ return querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error), nil
+ }
+
+ return querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result), nil
+}
+
+// handleUnsupportedCommand processes commands not in CommandsMeta
+func (h *BaseCommandHandler) handleUnsupportedCommand(resp ops.StoreResponse) (interface{}, error) {
+ if resp.EvalResponse.Error != nil {
+ return nil, resp.EvalResponse.Error
+ }
+ return resp.EvalResponse.Result, nil
+}
+
+// handleCommand processes commands based on their type
+func (h *BaseCommandHandler) handleCommand(cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) (interface{}, error) {
+ switch cmdMeta.CmdType {
+ case SingleShard, Custom:
+ if storeOp[0].EvalResponse.Error != nil {
+ return nil, storeOp[0].EvalResponse.Error
+ } else {
+ return storeOp[0].EvalResponse.Result, nil
+ }
+
+ case MultiShard, AllShard:
+ return cmdMeta.composeResponse(storeOp...), nil
+
+ default:
+ slog.Error("Unknown command type",
+ slog.String("id", h.id),
+ slog.String("command", diceDBCmd.Cmd),
+ slog.Any("evalResp", storeOp))
+ return nil, diceerrors.ErrInternalServer
+ }
+}
+
+func (h *BaseCommandHandler) handleError(err error) error {
+ if err != nil {
+ if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) {
+ slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err))
+ return err
+ }
+ }
+ return fmt.Errorf("error writing response: %v", err)
+}
+
+func (h *BaseCommandHandler) sendResponseToIOThread(resp interface{}, err error) {
+ if err != nil {
+ var customErr *diceerrors.PreProcessError
+ if errors.As(err, &customErr) {
+ h.ioThreadWriteChan <- customErr.Result
+ }
+ h.ioThreadWriteChan <- err
+ return
+ }
+ h.ioThreadWriteChan <- resp
+}
+
+func (h *BaseCommandHandler) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error {
+ if diceDBCmd.Cmd != auth.Cmd && !h.Session.IsActive() {
+ return errors.New("NOAUTH Authentication required")
+ }
+
+ return nil
+}
+
+func (h *BaseCommandHandler) Stop() error {
+ slog.Info("Stopping command handler", slog.String("id", h.id))
+ h.Session.Expire()
+ return nil
+}
+
+func GenerateUniqueRequestID() uint32 {
+ return atomic.AddUint32(&requestCounter, 1)
+}
diff --git a/internal/commandhandler/registry.go b/internal/commandhandler/registry.go
new file mode 100644
index 000000000..1ce93eb31
--- /dev/null
+++ b/internal/commandhandler/registry.go
@@ -0,0 +1,74 @@
+package commandhandler
+
+import (
+ "errors"
+ "sync"
+ "sync/atomic"
+
+ "github.com/dicedb/dice/internal/shard"
+)
+
+type Registry struct {
+ activeCmdHandlers sync.Map
+ numCmdHandlers atomic.Int32
+ maxClients int32
+ ShardManager *shard.ShardManager
+ mu sync.Mutex
+}
+
+var (
+ ErrMaxCmdHandlersReached = errors.New("maximum number of command handlers reached")
+ ErrCmdHandlerNotFound = errors.New("command handler not found")
+ ErrCmdHandlerNotBase = errors.New("command handler is not a BaseCommandHandler")
+)
+
+func NewRegistry(maxClients int32, sm *shard.ShardManager) *Registry {
+ return &Registry{
+ maxClients: maxClients,
+ ShardManager: sm,
+ }
+}
+
+func (m *Registry) RegisterCommandHandler(cmdHandler *BaseCommandHandler) error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.CommandHandlerCount() >= m.maxClients {
+ return ErrMaxCmdHandlersReached
+ }
+
+ responseChan := cmdHandler.responseChan
+ preprocessingChan := cmdHandler.preprocessingChan
+
+ if responseChan != nil && preprocessingChan != nil {
+ m.ShardManager.RegisterCommandHandler(cmdHandler.ID(), responseChan, preprocessingChan) // TODO: Change responseChan type to ShardResponse
+ } else if responseChan != nil && preprocessingChan == nil {
+ m.ShardManager.RegisterCommandHandler(cmdHandler.ID(), responseChan, nil)
+ }
+
+ m.activeCmdHandlers.Store(cmdHandler.ID(), cmdHandler)
+
+ m.numCmdHandlers.Add(1)
+ return nil
+}
+
+func (m *Registry) CommandHandlerCount() int32 {
+ return m.numCmdHandlers.Load()
+}
+
+func (m *Registry) UnregisterCommandHandler(id string) error {
+ m.ShardManager.UnregisterCommandHandler(id)
+ if cmdHandler, loaded := m.activeCmdHandlers.LoadAndDelete(id); loaded {
+ ch, ok := cmdHandler.(*BaseCommandHandler)
+ if !ok {
+ return ErrCmdHandlerNotBase
+ }
+ if err := ch.Stop(); err != nil {
+ return err
+ }
+ } else {
+ return ErrCmdHandlerNotFound
+ }
+ m.numCmdHandlers.Add(-1)
+ return nil
+}
diff --git a/internal/errors/errors.go b/internal/errors/errors.go
index 3d5b28e0a..04d7b1e3c 100644
--- a/internal/errors/errors.go
+++ b/internal/errors/errors.go
@@ -35,7 +35,7 @@ const (
WrongTypeErr = "-WRONGTYPE Operation against a key holding the wrong kind of value"
WrongTypeHllErr = "-WRONGTYPE Key is not a valid HyperLogLog string value."
InvalidHllErr = "-INVALIDOBJ Corrupted HLL object detected"
- IOThreadNotFoundErr = "io-thread with ID %s not found"
+ CmdHandlerNotFoundErr = "command handler with ID %s not found"
JSONPathValueTypeErr = "-WRONGTYPE wrong type of path value - expected string but found %s"
HashValueNotIntegerErr = "hash value is not an integer"
InternalServerError = "-ERR: Internal server error, unable to process command"
diff --git a/internal/iothread/iothread.go b/internal/iothread/iothread.go
index b7ef96ee6..d7c108235 100644
--- a/internal/iothread/iothread.go
+++ b/internal/iothread/iothread.go
@@ -18,35 +18,12 @@ package iothread
import (
"context"
- "errors"
- "fmt"
"log/slog"
- "net"
- "strconv"
- "strings"
- "sync/atomic"
- "syscall"
- "time"
- "github.com/dicedb/dice/config"
"github.com/dicedb/dice/internal/auth"
- "github.com/dicedb/dice/internal/clientio"
"github.com/dicedb/dice/internal/clientio/iohandler"
- "github.com/dicedb/dice/internal/clientio/requestparser"
- "github.com/dicedb/dice/internal/cmd"
- diceerrors "github.com/dicedb/dice/internal/errors"
- "github.com/dicedb/dice/internal/ops"
- "github.com/dicedb/dice/internal/querymanager"
- "github.com/dicedb/dice/internal/shard"
- "github.com/dicedb/dice/internal/wal"
- "github.com/dicedb/dice/internal/watchmanager"
- "github.com/google/uuid"
)
-const defaultRequestTimeout = 6 * time.Second
-
-var requestCounter uint32
-
// IOThread interface
type IOThread interface {
ID() string
@@ -56,35 +33,23 @@ type IOThread interface {
type BaseIOThread struct {
IOThread
- id string
- ioHandler iohandler.IOHandler
- parser requestparser.Parser
- shardManager *shard.ShardManager
- adhocReqChan chan *cmd.DiceDBCmd
- Session *auth.Session
- globalErrorChan chan error
- responseChan chan *ops.StoreResponse
- preprocessingChan chan *ops.StoreResponse
- cmdWatchSubscriptionChan chan watchmanager.WatchSubscription
- wl wal.AbstractWAL
+ id string
+ ioHandler iohandler.IOHandler
+ Session *auth.Session
+ ioThreadReadChan chan []byte // Channel to send data to the command handler
+ ioThreadWriteChan chan interface{} // Channel to receive data from the command handler
+ ioThreadErrChan chan error // Channel to receive errors from the ioHandler
}
-func NewIOThread(wid string, responseChan, preprocessingChan chan *ops.StoreResponse,
- cmdWatchSubscriptionChan chan watchmanager.WatchSubscription,
- ioHandler iohandler.IOHandler, parser requestparser.Parser,
- shardManager *shard.ShardManager, gec chan error, wl wal.AbstractWAL) *BaseIOThread {
+func NewIOThread(id string, ioHandler iohandler.IOHandler,
+ ioThreadReadChan chan []byte, ioThreadWriteChan chan interface{}, ioThreadErrChan chan error) *BaseIOThread {
return &BaseIOThread{
- id: wid,
- ioHandler: ioHandler,
- parser: parser,
- shardManager: shardManager,
- globalErrorChan: gec,
- responseChan: responseChan,
- preprocessingChan: preprocessingChan,
- Session: auth.NewSession(),
- adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize),
- cmdWatchSubscriptionChan: cmdWatchSubscriptionChan,
- wl: wl,
+ id: id,
+ ioHandler: ioHandler,
+ Session: auth.NewSession(),
+ ioThreadReadChan: ioThreadReadChan,
+ ioThreadWriteChan: ioThreadWriteChan,
+ ioThreadErrChan: ioThreadErrChan,
}
}
@@ -93,9 +58,9 @@ func (t *BaseIOThread) ID() string {
}
func (t *BaseIOThread) Start(ctx context.Context) error {
- errChan := make(chan error, 1)
- incomingDataChan := make(chan []byte)
- readErrChan := make(chan error)
+ // local channels to communicate between Start and startInputReader goroutine
+ incomingDataChan := make(chan []byte) // data channel
+ readErrChan := make(chan error) // error channel
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()
@@ -111,17 +76,17 @@ func (t *BaseIOThread) Start(ctx context.Context) error {
slog.Warn("Error stopping io-thread:", slog.String("id", t.id), slog.Any("error", err))
}
return ctx.Err()
- case err := <-errChan:
- return t.handleError(err)
- case cmdReq := <-t.adhocReqChan:
- t.handleCmdRequestWithTimeout(ctx, errChan, []*cmd.DiceDBCmd{cmdReq}, true, defaultRequestTimeout)
case data := <-incomingDataChan:
- if err := t.processIncomingData(ctx, &data, errChan); err != nil {
- return err
- }
+ t.ioThreadReadChan <- data
case err := <-readErrChan:
slog.Debug("Read error in io-thread, connection closed possibly", slog.String("id", t.id), slog.Any("error", err))
+ t.ioThreadErrChan <- err
return err
+ case resp := <-t.ioThreadWriteChan:
+ err := t.ioHandler.Write(ctx, resp)
+ if err != nil {
+ slog.Debug("Error sending response to client", slog.String("id", t.id), slog.Any("error", err))
+ }
}
}
}
@@ -149,484 +114,8 @@ func (t *BaseIOThread) startInputReader(ctx context.Context, incomingDataChan ch
}
}
-func (t *BaseIOThread) handleError(err error) error {
- if err != nil {
- if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) {
- slog.Debug("Connection closed for io-thread", slog.String("id", t.id), slog.Any("error", err))
- return err
- }
- }
- return fmt.Errorf("error writing response: %v", err)
-}
-
-func (t *BaseIOThread) processIncomingData(ctx context.Context, data *[]byte, errChan chan error) error {
- commands, err := t.parser.Parse(*data)
-
- if err != nil {
- err = t.ioHandler.Write(ctx, err)
- if err != nil {
- slog.Debug("Write error, connection closed possibly", slog.String("id", t.id), slog.Any("error", err))
- return err
- }
- return nil
- }
-
- if len(commands) == 0 {
- err = t.ioHandler.Write(ctx, fmt.Errorf("ERR: Invalid request"))
- if err != nil {
- slog.Debug("Write error, connection closed possibly", slog.String("id", t.id), slog.Any("error", err))
- return err
- }
- return nil
- }
-
- // DiceDB supports clients to send only one request at a time
- // We also need to ensure that the client is blocked until the response is received
- if len(commands) > 1 {
- err = t.ioHandler.Write(ctx, fmt.Errorf("ERR: Multiple commands not supported"))
- if err != nil {
- slog.Debug("Write error, connection closed possibly", slog.String("id", t.id), slog.Any("error", err))
- return err
- }
- }
-
- err = t.isAuthenticated(commands[0])
- if err != nil {
- writeErr := t.ioHandler.Write(ctx, err)
- if writeErr != nil {
- slog.Debug("Write error, connection closed possibly", slog.Any("error", errors.Join(err, writeErr)))
- return errors.Join(err, writeErr)
- }
- return nil
- }
-
- t.handleCmdRequestWithTimeout(ctx, errChan, commands, false, defaultRequestTimeout)
- return nil
-}
-
-func (t *BaseIOThread) handleCmdRequestWithTimeout(ctx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) {
- execCtx, cancel := context.WithTimeout(ctx, timeout)
- defer cancel()
- t.executeCommandHandler(execCtx, errChan, commands, isWatchNotification)
-}
-
-func (t *BaseIOThread) executeCommandHandler(execCtx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) {
- // Retrieve metadata for the command to determine if multisharding is supported.
- meta, ok := CommandsMeta[commands[0].Cmd]
- if ok && meta.preProcessing {
- if err := meta.preProcessResponse(t, commands[0]); err != nil {
- e := t.ioHandler.Write(execCtx, err)
- if e != nil {
- slog.Debug("Error executing for io-thread", slog.String("id", t.id), slog.Any("error", err))
- }
- }
- }
-
- err := t.executeCommand(execCtx, commands[0], isWatchNotification)
- if err != nil {
- slog.Error("Error executing command", slog.String("id", t.id), slog.Any("error", err))
- if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) {
- slog.Debug("Connection closed for io-thread", slog.String("id", t.id), slog.Any("error", err))
- errChan <- err
- }
- }
-}
-
-func (t *BaseIOThread) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) error {
- // Break down the single command into multiple commands if multisharding is supported.
- // The length of cmdList helps determine how many shards to wait for responses.
- cmdList := make([]*cmd.DiceDBCmd, 0)
- var watchLabel string
-
- // Retrieve metadata for the command to determine if multisharding is supported.
- meta, ok := CommandsMeta[diceDBCmd.Cmd]
- if !ok {
- // If no metadata exists, treat it as a single command and not migrated
- cmdList = append(cmdList, diceDBCmd)
- } else {
- // Depending on the command type, decide how to handle it.
- switch meta.CmdType {
- case Global:
- // If it's a global command, process it immediately without involving any shards.
- err := t.ioHandler.Write(ctx, meta.IOThreadHandler(diceDBCmd.Args))
- slog.Debug("Error executing command on io-thread", slog.String("id", t.id), slog.Any("error", err))
- return err
-
- case SingleShard:
- // For single-shard or custom commands, process them without breaking up.
- cmdList = append(cmdList, diceDBCmd)
-
- case MultiShard, AllShard:
- var err error
- // If the command supports multisharding, break it down into multiple commands.
- cmdList, err = meta.decomposeCommand(ctx, t, diceDBCmd)
- if err != nil {
- var ioErr error
- // Check if it's a CustomError
- var customErr *diceerrors.PreProcessError
- if errors.As(err, &customErr) {
- ioErr = t.ioHandler.Write(ctx, customErr.Result)
- } else {
- ioErr = t.ioHandler.Write(ctx, err)
- }
- if ioErr != nil {
- slog.Debug("Error executing for io-thread", slog.String("id", t.id), slog.Any("error", ioErr))
- }
- return ioErr
- }
-
- case Custom:
- return t.handleCustomCommands(ctx, diceDBCmd)
-
- case Watch:
- // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass
- // it along as is.
- // Modify the command name to remove the .WATCH suffix, this will allow us to generate a consistent
- // fingerprint (which uses the command name without the suffix)
- diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-6]
-
- // check if the last argument is a watch label
- label := diceDBCmd.Args[len(diceDBCmd.Args)-1]
- if _, err := uuid.Parse(label); err == nil {
- watchLabel = label
-
- // remove the watch label from the args
- diceDBCmd.Args = diceDBCmd.Args[:len(diceDBCmd.Args)-1]
- }
-
- watchCmd := &cmd.DiceDBCmd{
- Cmd: diceDBCmd.Cmd,
- Args: diceDBCmd.Args,
- }
- cmdList = append(cmdList, watchCmd)
- isWatchNotification = true
-
- case Unwatch:
- // Generate the Cmd being unwatched. All we need to do is remove the .UNWATCH suffix from the command and pass
- // it along as is.
- // Modify the command name to remove the .UNWATCH suffix, this will allow us to generate a consistent
- // fingerprint (which uses the command name without the suffix)
- diceDBCmd.Cmd = diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-8]
- watchCmd := &cmd.DiceDBCmd{
- Cmd: diceDBCmd.Cmd,
- Args: diceDBCmd.Args,
- }
- cmdList = append(cmdList, watchCmd)
- isWatchNotification = false
- }
- }
-
- // Unsubscribe Unwatch command type
- if meta.CmdType == Unwatch {
- return t.handleCommandUnwatch(ctx, cmdList)
- }
-
- // Scatter the broken-down commands to the appropriate shards.
- if err := t.scatter(ctx, cmdList, meta.CmdType); err != nil {
- return err
- }
-
- // Gather the responses from the shards and write them to the buffer.
- if err := t.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel); err != nil {
- return err
- }
-
- if meta.CmdType == Watch {
- // Proceed to subscribe after successful execution
- t.handleCommandWatch(cmdList)
- }
-
- return nil
-}
-
-func (t *BaseIOThread) handleCustomCommands(ctx context.Context, diceDBCmd *cmd.DiceDBCmd) error {
- // if command is of type Custom, write a custom logic around it
- switch diceDBCmd.Cmd {
- case CmdAuth:
- err := t.ioHandler.Write(ctx, t.RespAuth(diceDBCmd.Args))
- if err != nil {
- slog.Error("Error sending auth response to io-thread", slog.String("id", t.id), slog.Any("error", err))
- }
- return err
- case CmdEcho:
- err := t.ioHandler.Write(ctx, RespEcho(diceDBCmd.Args))
- if err != nil {
- slog.Error("Error sending echo response to io-thread", slog.String("id", t.id), slog.Any("error", err))
- }
- return err
- case CmdAbort:
- err := t.ioHandler.Write(ctx, clientio.OK)
- if err != nil {
- slog.Error("Error sending abort response to io-thread", slog.String("id", t.id), slog.Any("error", err))
- }
- slog.Info("Received ABORT command, initiating server shutdown", slog.String("id", t.id))
- t.globalErrorChan <- diceerrors.ErrAborted
- return err
- case CmdPing:
- err := t.ioHandler.Write(ctx, RespPING(diceDBCmd.Args))
- if err != nil {
- slog.Error("Error sending ping response to io-thread", slog.String("id", t.id), slog.Any("error", err))
- }
- return err
- case CmdHello:
- err := t.ioHandler.Write(ctx, RespHello(diceDBCmd.Args))
- if err != nil {
- slog.Error("Error sending ping response to io-thread", slog.String("id", t.id), slog.Any("error", err))
- }
- return err
- case CmdSleep:
- err := t.ioHandler.Write(ctx, RespSleep(diceDBCmd.Args))
- if err != nil {
- slog.Error("Error sending ping response to io-thread", slog.String("id", t.id), slog.Any("error", err))
- }
- return err
- default:
- return diceerrors.ErrUnknownCmd(diceDBCmd.Cmd)
- }
-}
-
-// handleCommandWatch sends a watch subscription request to the watch manager.
-func (t *BaseIOThread) handleCommandWatch(cmdList []*cmd.DiceDBCmd) {
- t.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{
- Subscribe: true,
- WatchCmd: cmdList[len(cmdList)-1],
- AdhocReqChan: t.adhocReqChan,
- }
-}
-
-// handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client.
-// The response is sent before the unwatch request is processed by the watch manager.
-func (t *BaseIOThread) handleCommandUnwatch(ctx context.Context, cmdList []*cmd.DiceDBCmd) error {
- // extract the fingerprint
- command := cmdList[len(cmdList)-1]
- fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32)
- if parseErr != nil {
- err := t.ioHandler.Write(ctx, diceerrors.ErrInvalidFingerprint)
- if err != nil {
- return fmt.Errorf("error sending push response to client: %v", err)
- }
- return parseErr
- }
-
- // send the unsubscribe request
- t.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{
- Subscribe: false,
- AdhocReqChan: t.adhocReqChan,
- Fingerprint: uint32(fp),
- }
-
- err := t.ioHandler.Write(ctx, clientio.RespOK)
- if err != nil {
- return fmt.Errorf("error sending push response to client: %v", err)
- }
- return nil
-}
-
-// scatter distributes the DiceDB commands to the respective shards based on the key.
-// For each command, it calculates the shard ID and sends the command to the shard's request channel for processing.
-func (t *BaseIOThread) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd, cmdType CmdType) error {
- // Otherwise check for the shard based on the key using hash
- // and send it to the particular shard
- // Check if the context has been canceled or expired.
- select {
- case <-ctx.Done():
- // If the context is canceled, return the error associated with it.
- return ctx.Err()
- default:
- // Proceed with the default case when the context is not canceled.
-
- if cmdType == AllShard {
- // If the command type is for all shards, iterate over all available shards.
- for i := uint8(0); i < uint8(t.shardManager.GetShardCount()); i++ {
- // Get the shard ID (i) and its associated request channel.
- shardID, responseChan := i, t.shardManager.GetShard(i).ReqChan
-
- // Send a StoreOp operation to the shard's request channel.
- responseChan <- &ops.StoreOp{
- SeqID: i, // Sequence ID for this operation.
- RequestID: GenerateUniqueRequestID(), // Unique identifier for the request.
- Cmd: cmds[0], // Command to be executed, using the first command in cmds.
- IOThreadID: t.id, // ID of the current io-thread.
- ShardID: shardID, // ID of the shard handling this operation.
- Client: nil, // Client information (if applicable).
- }
- }
- } else {
- // If the command type is specific to certain commands, process them individually.
- for i := uint8(0); i < uint8(len(cmds)); i++ {
- // Determine the appropriate shard for the current command using a routing key.
- shardID, responseChan := t.shardManager.GetShardInfo(getRoutingKeyFromCommand(cmds[i]))
-
- // Send a StoreOp operation to the shard's request channel.
- responseChan <- &ops.StoreOp{
- SeqID: i, // Sequence ID for this operation.
- RequestID: GenerateUniqueRequestID(), // Unique identifier for the request.
- Cmd: cmds[i], // Command to be executed, using the current command in cmds.
- IOThreadID: t.id, // ID of the current io-thread.
- ShardID: shardID, // ID of the shard handling this operation.
- Client: nil, // Client information (if applicable).
- }
- }
- }
- }
-
- return nil
-}
-
-// getRoutingKeyFromCommand determines the key used for shard routing
-func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string {
- if len(diceDBCmd.Args) > 0 {
- return diceDBCmd.Args[0]
- }
- return diceDBCmd.Cmd
-}
-
-// gather collects the responses from multiple shards and writes the results into the provided buffer.
-// It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard).
-func (t *BaseIOThread) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) error {
- // Collect responses from all shards
- storeOp, err := t.gatherResponses(ctx, numCmds)
- if err != nil {
- return err
- }
-
- if len(storeOp) == 0 {
- slog.Error("No response from shards",
- slog.String("id", t.id),
- slog.String("command", diceDBCmd.Cmd))
- return fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd)
- }
-
- if isWatchNotification {
- return t.handleWatchNotification(ctx, diceDBCmd, storeOp[0], watchLabel)
- }
-
- // Process command based on its type
- cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd]
- if !ok {
- return t.handleUnsupportedCommand(ctx, storeOp[0])
- }
-
- return t.handleCommand(ctx, cmdMeta, diceDBCmd, storeOp)
-}
-
-// gatherResponses collects responses from all shards
-func (t *BaseIOThread) gatherResponses(ctx context.Context, numCmds int) ([]ops.StoreResponse, error) {
- storeOp := make([]ops.StoreResponse, 0, numCmds)
-
- for numCmds > 0 {
- select {
- case <-ctx.Done():
- slog.Error("Timed out waiting for response from shards",
- slog.String("id", t.id),
- slog.Any("error", ctx.Err()))
- return nil, ctx.Err()
-
- case resp, ok := <-t.responseChan:
- if ok {
- storeOp = append(storeOp, *resp)
- }
- numCmds--
-
- case sError, ok := <-t.shardManager.ShardErrorChan:
- if ok {
- slog.Error("Error from shard",
- slog.String("id", t.id),
- slog.Any("error", sError))
- return nil, sError.Error
- }
- }
- }
-
- return storeOp, nil
-}
-
-// handleWatchNotification processes watch notification responses
-func (t *BaseIOThread) handleWatchNotification(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) error {
- fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint())
-
- // if watch label is not empty, then this is the first response for the watch command
- // hence, we will send the watch label as part of the response
- firstRespElem := diceDBCmd.Cmd
- if watchLabel != "" {
- firstRespElem = watchLabel
- }
-
- if resp.EvalResponse.Error != nil {
- return t.writeResponse(ctx, querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error))
- }
-
- return t.writeResponse(ctx, querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result))
-}
-
-// handleUnsupportedCommand processes commands not in CommandsMeta
-func (t *BaseIOThread) handleUnsupportedCommand(ctx context.Context, resp ops.StoreResponse) error {
- if resp.EvalResponse.Error != nil {
- return t.writeResponse(ctx, resp.EvalResponse.Error)
- }
- return t.writeResponse(ctx, resp.EvalResponse.Result)
-}
-
-// handleCommand processes commands based on their type
-func (t *BaseIOThread) handleCommand(ctx context.Context, cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) error {
- var err error
-
- switch cmdMeta.CmdType {
- case SingleShard, Custom:
- if storeOp[0].EvalResponse.Error != nil {
- err = t.writeResponse(ctx, storeOp[0].EvalResponse.Error)
- } else {
- err = t.writeResponse(ctx, storeOp[0].EvalResponse.Result)
- }
-
- if err == nil && t.wl != nil {
- if err := t.wl.LogCommand([]byte(fmt.Sprintf("%s %s", diceDBCmd.Cmd, strings.Join(diceDBCmd.Args, " ")))); err != nil {
- return err
- }
- }
- case MultiShard, AllShard:
- err = t.writeResponse(ctx, cmdMeta.composeResponse(storeOp...))
-
- if err == nil && t.wl != nil {
- if err := t.wl.LogCommand([]byte(fmt.Sprintf("%s %s", diceDBCmd.Cmd, strings.Join(diceDBCmd.Args, " ")))); err != nil {
- return err
- }
- }
- default:
- slog.Error("Unknown command type",
- slog.String("id", t.id),
- slog.String("command", diceDBCmd.Cmd),
- slog.Any("evalResp", storeOp))
- err = t.writeResponse(ctx, diceerrors.ErrInternalServer)
- }
- return err
-}
-
-// writeResponse handles writing responses and logging errors
-func (t *BaseIOThread) writeResponse(ctx context.Context, response interface{}) error {
- err := t.ioHandler.Write(ctx, response)
- if err != nil {
- slog.Debug("Error sending response to client",
- slog.String("id", t.id),
- slog.Any("error", err))
- }
- return err
-}
-
-func (t *BaseIOThread) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error {
- if diceDBCmd.Cmd != auth.Cmd && !t.Session.IsActive() {
- return errors.New("NOAUTH Authentication required")
- }
-
- return nil
-}
-
func (t *BaseIOThread) Stop() error {
slog.Info("Stopping io-thread", slog.String("id", t.id))
t.Session.Expire()
return nil
}
-
-func GenerateUniqueRequestID() uint32 {
- return atomic.AddUint32(&requestCounter, 1)
-}
diff --git a/internal/iothread/manager.go b/internal/iothread/manager.go
index 3cf95eaed..9bcfcb451 100644
--- a/internal/iothread/manager.go
+++ b/internal/iothread/manager.go
@@ -20,15 +20,12 @@ import (
"errors"
"sync"
"sync/atomic"
-
- "github.com/dicedb/dice/internal/shard"
)
type Manager struct {
connectedClients sync.Map
numIOThreads atomic.Int32
maxClients int32
- shardManager *shard.ShardManager
mu sync.Mutex
}
@@ -37,10 +34,9 @@ var (
ErrIOThreadNotFound = errors.New("io-thread not found")
)
-func NewManager(maxClients int32, sm *shard.ShardManager) *Manager {
+func NewManager(maxClients int32) *Manager {
return &Manager{
- maxClients: maxClients,
- shardManager: sm,
+ maxClients: maxClients,
}
}
@@ -53,14 +49,6 @@ func (m *Manager) RegisterIOThread(ioThread IOThread) error {
}
m.connectedClients.Store(ioThread.ID(), ioThread)
- responseChan := ioThread.(*BaseIOThread).responseChan
- preprocessingChan := ioThread.(*BaseIOThread).preprocessingChan
-
- if responseChan != nil && preprocessingChan != nil {
- m.shardManager.RegisterIOThread(ioThread.ID(), responseChan, preprocessingChan) // TODO: Change responseChan type to ShardResponse
- } else if responseChan != nil && preprocessingChan == nil {
- m.shardManager.RegisterIOThread(ioThread.ID(), responseChan, nil)
- }
m.numIOThreads.Add(1)
return nil
@@ -88,8 +76,6 @@ func (m *Manager) UnregisterIOThread(id string) error {
return ErrIOThreadNotFound
}
- m.shardManager.UnregisterIOThread(id)
m.numIOThreads.Add(-1)
-
return nil
}
diff --git a/internal/ops/store_op.go b/internal/ops/store_op.go
index 376223a00..236f63fea 100644
--- a/internal/ops/store_op.go
+++ b/internal/ops/store_op.go
@@ -27,8 +27,8 @@ type StoreOp struct {
RequestID uint32 // RequestID identifies the request that this StoreOp belongs to
Cmd *cmd.DiceDBCmd // Cmd is the atomic Store command (e.g., GET, SET)
ShardID uint8 // ShardID of the shard on which the Store command will be executed
- IOThreadID string // IOThreadID is the ID of the io-thread that sent this Store operation
- Client *comm.Client // Client that sent this Store operation. TODO: This can potentially replace the IOThreadID in the future
+ CmdHandlerID string // CmdHandlerID is the ID of the command handler that sent this Store operation
+ Client *comm.Client // Client that sent this Store operation. TODO: This can potentially replace the CmdHandlerID in the future
HTTPOp bool // HTTPOp is true if this Store operation is an HTTP operation
WebsocketOp bool // WebsocketOp is true if this Store operation is a Websocket operation
PreProcessing bool // PreProcessing indicates whether a comamnd operation requires preprocessing before execution. This is mainly used is multi-step-multi-shard commands
diff --git a/internal/server/httpws/httpServer.go b/internal/server/httpws/httpServer.go
index ab24bce51..f4b82da6d 100644
--- a/internal/server/httpws/httpServer.go
+++ b/internal/server/httpws/httpServer.go
@@ -28,7 +28,7 @@ import (
"sync"
"time"
- "github.com/dicedb/dice/internal/iothread"
+ "github.com/dicedb/dice/internal/commandhandler"
"github.com/dicedb/dice/internal/eval"
"github.com/dicedb/dice/internal/server/abstractserver"
@@ -44,8 +44,9 @@ import (
)
const (
- Abort = "ABORT"
- stringNil = "(nil)"
+ Abort = "ABORT"
+ stringNil = "(nil)"
+ httpCmdHandlerID = "httpServer"
)
var unimplementedCommands = map[string]bool{
@@ -113,7 +114,7 @@ func (s *HTTPServer) Run(ctx context.Context) error {
httpCtx, cancelHTTP := context.WithCancel(ctx)
defer cancelHTTP()
- s.shardManager.RegisterIOThread("httpServer", s.ioChan, nil)
+ s.shardManager.RegisterCommandHandler(httpCmdHandlerID, s.ioChan, nil)
wg.Add(1)
go func() {
@@ -157,7 +158,7 @@ func (s *HTTPServer) DiceHTTPHandler(writer http.ResponseWriter, request *http.R
return
}
- if iothread.CommandsMeta[diceDBCmd.Cmd].CmdType == iothread.MultiShard {
+ if commandhandler.CommandsMeta[diceDBCmd.Cmd].CmdType == commandhandler.MultiShard {
writeErrorResponse(writer, http.StatusBadRequest, "unsupported command",
"Unsupported command received", slog.String("cmd", diceDBCmd.Cmd))
return
@@ -179,10 +180,10 @@ func (s *HTTPServer) DiceHTTPHandler(writer http.ResponseWriter, request *http.R
// send request to Shard Manager
s.shardManager.GetShard(0).ReqChan <- &ops.StoreOp{
- Cmd: diceDBCmd,
- IOThreadID: "httpServer",
- ShardID: 0,
- HTTPOp: true,
+ Cmd: diceDBCmd,
+ CmdHandlerID: httpCmdHandlerID,
+ ShardID: 0,
+ HTTPOp: true,
}
// Wait for response
@@ -230,11 +231,11 @@ func (s *HTTPServer) DiceHTTPQwatchHandler(writer http.ResponseWriter, request *
qwatchClient := comm.NewHTTPQwatchClient(s.qwatchResponseChan, clientIdentifierID)
// Prepare the store operation
storeOp := &ops.StoreOp{
- Cmd: diceDBCmd,
- IOThreadID: "httpServer",
- ShardID: 0,
- Client: qwatchClient,
- HTTPOp: true,
+ Cmd: diceDBCmd,
+ CmdHandlerID: httpCmdHandlerID,
+ ShardID: 0,
+ Client: qwatchClient,
+ HTTPOp: true,
}
slog.Info("Registered client for watching query", slog.Any("clientID", clientIdentifierID),
@@ -354,9 +355,9 @@ func (s *HTTPServer) writeResponse(writer http.ResponseWriter, result *ops.Store
// Check if the command is migrated, if it is we use EvalResponse values
// else we use RESPParser to decode the response
- _, ok := iothread.CommandsMeta[diceDBCmd.Cmd]
+ _, ok := commandhandler.CommandsMeta[diceDBCmd.Cmd]
// TODO: Remove this conditional check and if (true) condition when all commands are migrated
- if !ok || iothread.CommandsMeta[diceDBCmd.Cmd].CmdType == iothread.Custom {
+ if !ok || commandhandler.CommandsMeta[diceDBCmd.Cmd].CmdType == commandhandler.Custom {
responseValue, err = DecodeEvalResponse(result.EvalResponse)
if err != nil {
slog.Error("Error decoding response", "error", err)
diff --git a/internal/server/httpws/websocketServer.go b/internal/server/httpws/websocketServer.go
index 89e8ed4a1..f1fb8b142 100644
--- a/internal/server/httpws/websocketServer.go
+++ b/internal/server/httpws/websocketServer.go
@@ -30,7 +30,7 @@ import (
"syscall"
"time"
- "github.com/dicedb/dice/internal/iothread"
+ "github.com/dicedb/dice/internal/commandhandler"
"github.com/dicedb/dice/internal/server/abstractserver"
"github.com/dicedb/dice/internal/wal"
@@ -46,9 +46,12 @@ import (
"golang.org/x/exp/rand"
)
-const Qwatch = "Q.WATCH"
-const Qunwatch = "Q.UNWATCH"
-const Subscribe = "SUBSCRIBE"
+const (
+ Qwatch = "Q.WATCH"
+ Qunwatch = "Q.UNWATCH"
+ Subscribe = "SUBSCRIBE"
+ wsCmdHandlerID = "wsServer"
+)
var unimplementedCommandsWebsocket = map[string]bool{
Qunwatch: true,
@@ -96,7 +99,7 @@ func (s *WebsocketServer) Run(ctx context.Context) error {
websocketCtx, cancelWebsocket := context.WithCancel(ctx)
defer cancelWebsocket()
- s.shardManager.RegisterIOThread("wsServer", s.ioChan, nil)
+ s.shardManager.RegisterCommandHandler(wsCmdHandlerID, s.ioChan, nil)
wg.Add(1)
go func() {
@@ -170,7 +173,7 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques
continue
}
- if iothread.CommandsMeta[diceDBCmd.Cmd].CmdType == iothread.MultiShard {
+ if commandhandler.CommandsMeta[diceDBCmd.Cmd].CmdType == commandhandler.MultiShard {
if err := WriteResponseWithRetries(conn, []byte("error: unsupported command"), maxRetries); err != nil {
slog.Debug(fmt.Sprintf("Error writing message: %v", err))
}
@@ -192,10 +195,10 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques
// create request
sp := &ops.StoreOp{
- Cmd: diceDBCmd,
- IOThreadID: "wsServer",
- ShardID: 0,
- WebsocketOp: true,
+ Cmd: diceDBCmd,
+ CmdHandlerID: wsCmdHandlerID,
+ ShardID: 0,
+ WebsocketOp: true,
}
// handle q.watch commands
@@ -295,7 +298,7 @@ func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.D
var responseValue interface{}
// Check if the command is migrated, if it is we use EvalResponse values
// else we use RESPParser to decode the response
- _, ok := iothread.CommandsMeta[diceDBCmd.Cmd]
+ _, ok := commandhandler.CommandsMeta[diceDBCmd.Cmd]
// TODO: Remove this conditional check and if (true) condition when all commands are migrated
if !ok {
responseValue, err = DecodeEvalResponse(response.EvalResponse)
diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go
index 9a23184dc..3c83b71d1 100644
--- a/internal/server/resp/server.go
+++ b/internal/server/resp/server.go
@@ -27,6 +27,8 @@ import (
"syscall"
"time"
+ "github.com/dicedb/dice/internal/commandhandler"
+ "github.com/dicedb/dice/internal/ops"
"github.com/dicedb/dice/internal/server/abstractserver"
"github.com/dicedb/dice/internal/wal"
@@ -37,13 +39,13 @@ import (
"github.com/dicedb/dice/internal/clientio/iohandler/netconn"
respparser "github.com/dicedb/dice/internal/clientio/requestparser/resp"
"github.com/dicedb/dice/internal/iothread"
- "github.com/dicedb/dice/internal/ops"
"github.com/dicedb/dice/internal/shard"
)
var (
- ioThreadCounter uint64
- startTime = time.Now().UnixNano() / int64(time.Millisecond)
+ ioThreadCounter uint64
+ cmdHandlerCounter uint64
+ startTime = time.Now().UnixNano() / int64(time.Millisecond)
)
var (
@@ -61,6 +63,7 @@ type Server struct {
serverFD int
connBacklogSize int
ioThreadManager *iothread.Manager
+ cmdHandlerManager *commandhandler.Registry
shardManager *shard.ShardManager
watchManager *watchmanager.Manager
cmdWatchSubscriptionChan chan watchmanager.WatchSubscription
@@ -68,13 +71,15 @@ type Server struct {
wl wal.AbstractWAL
}
-func NewServer(shardManager *shard.ShardManager, ioThreadManager *iothread.Manager,
- cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent, globalErrChan chan error, wl wal.AbstractWAL) *Server {
+func NewServer(shardManager *shard.ShardManager, ioThreadManager *iothread.Manager, cmdHandlerManager *commandhandler.Registry,
+ cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent,
+ globalErrChan chan error, wl wal.AbstractWAL) *Server {
return &Server{
Host: config.DiceConfig.RespServer.Addr,
Port: config.DiceConfig.RespServer.Port,
connBacklogSize: DefaultConnBacklogSize,
ioThreadManager: ioThreadManager,
+ cmdHandlerManager: cmdHandlerManager,
shardManager: shardManager,
watchManager: watchmanager.NewManager(cmdWatchSubscriptionChan, cmdWatchChan),
cmdWatchSubscriptionChan: cmdWatchSubscriptionChan,
@@ -207,13 +212,22 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou
return err
}
- parser := respparser.NewParser()
+ // create a new io-thread
+ ioThreadID := GenerateUniqueIOThreadID()
+ ioThreadReadChan := make(chan []byte) // for sending data to the command handler from the io-thread
+ ioThreadWriteChan := make(chan interface{}) // for sending data to the io-thread from the command handler
+ ioThreadErrChan := make(chan error, 1) // for receiving errors from the io-thread
+ thread := iothread.NewIOThread(ioThreadID, ioHandler, ioThreadReadChan, ioThreadWriteChan, ioThreadErrChan)
+ // For each io-thread, we create a dedicated command handler - 1:1 mapping
+ cmdHandlerID := GenerateUniqueCommandHandlerID()
+ parser := respparser.NewParser()
responseChan := make(chan *ops.StoreResponse) // responseChan is used for handling common responses from shards
preprocessingChan := make(chan *ops.StoreResponse) // preprocessingChan is specifically for handling responses from shards for commands that require preprocessing
- ioThreadID := GenerateUniqueIOThreadID()
- thread := iothread.NewIOThread(ioThreadID, responseChan, preprocessingChan, s.cmdWatchSubscriptionChan, ioHandler, parser, s.shardManager, s.globalErrorChan, s.wl)
+ handler := commandhandler.NewCommandHandler(cmdHandlerID, responseChan, preprocessingChan,
+ s.cmdWatchSubscriptionChan, parser, s.shardManager, s.globalErrorChan,
+ ioThreadReadChan, ioThreadWriteChan, ioThreadErrChan, s.wl)
// Register the io-thread with the manager
err = s.ioThreadManager.RegisterIOThread(thread)
@@ -223,6 +237,15 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou
wg.Add(1)
go s.startIOThread(ctx, wg, thread)
+
+ // Register the command handler with the manager
+ err = s.cmdHandlerManager.RegisterCommandHandler(handler)
+ if err != nil {
+ return err
+ }
+
+ wg.Add(1)
+ go s.startCommandHandler(ctx, wg, handler)
}
}
}
@@ -243,12 +266,34 @@ func (s *Server) startIOThread(ctx context.Context, wg *sync.WaitGroup, thread *
}
}
+func (s *Server) startCommandHandler(ctx context.Context, wg *sync.WaitGroup, cmdHandler *commandhandler.BaseCommandHandler) {
+ wg.Done()
+ defer func(wm *commandhandler.Registry, id string) {
+ err := wm.UnregisterCommandHandler(id)
+ if err != nil {
+ slog.Warn("Failed to unregister command handler", slog.String("id", id), slog.Any("error", err))
+ }
+ }(s.cmdHandlerManager, cmdHandler.ID())
+ ctx2, cancel := context.WithCancel(ctx)
+ defer cancel()
+ err := cmdHandler.Start(ctx2)
+ if err != nil {
+ slog.Debug("CommandHandler stopped", slog.String("id", cmdHandler.ID()), slog.Any("error", err))
+ }
+}
+
func GenerateUniqueIOThreadID() string {
count := atomic.AddUint64(&ioThreadCounter, 1)
timestamp := time.Now().UnixNano()/int64(time.Millisecond) - startTime
return fmt.Sprintf("W-%d-%d", timestamp, count)
}
+func GenerateUniqueCommandHandlerID() string {
+ count := atomic.AddUint64(&cmdHandlerCounter, 1)
+ timestamp := time.Now().UnixNano()/int64(time.Millisecond) - startTime
+ return fmt.Sprintf("W-%d-%d", timestamp, count)
+}
+
func (s *Server) Shutdown() {
// Not implemented
}
diff --git a/internal/shard/shard_manager.go b/internal/shard/shard_manager.go
index 7fdc83e66..4ff603817 100644
--- a/internal/shard/shard_manager.go
+++ b/internal/shard/shard_manager.go
@@ -120,15 +120,16 @@ func (manager *ShardManager) GetShard(id ShardID) *ShardThread {
return nil
}
-// RegisterIOThread registers a io-thread with all Shards present in the ShardManager.
-func (manager *ShardManager) RegisterIOThread(id string, request, processing chan *ops.StoreResponse) {
+// RegisterCommandHandler registers a command handler with all Shards present in the ShardManager.
+func (manager *ShardManager) RegisterCommandHandler(id string, request, processing chan *ops.StoreResponse) {
for _, shard := range manager.shards {
- shard.registerIOThread(id, request, processing)
+ shard.registerCommandHandler(id, request, processing)
}
}
-func (manager *ShardManager) UnregisterIOThread(id string) {
+// UnregisterCommandHandler unregisters a command handler from all Shards present in the ShardManager.
+func (manager *ShardManager) UnregisterCommandHandler(id string) {
for _, shard := range manager.shards {
- shard.unregisterIOThread(id)
+ shard.unregisterCommandHandler(id)
}
}
diff --git a/internal/shard/shard_thread.go b/internal/shard/shard_thread.go
index 02537e4f6..b7c93d14f 100644
--- a/internal/shard/shard_thread.go
+++ b/internal/shard/shard_thread.go
@@ -37,23 +37,24 @@ type ShardError struct {
Error error // Error is the error that occurred
}
-// IOChannels holds the communication channels for an io-thread.
-// It contains both the common response channel and the preprocessing response channel.
-type IOChannels struct {
- CommonResponseChan chan *ops.StoreResponse // CommonResponseChan is used to send standard responses for io-thread operations.
+// CmdHandlerChannels holds the communication channels for a Command Handler.
+// It contains both the response channel and the preprocessing response channel.
+type CmdHandlerChannels struct {
+ ResponseChan chan *ops.StoreResponse // ResponseChan is used to send standard responses for Command Handler operations.
PreProcessingResponseChan chan *ops.StoreResponse // PreProcessingResponseChan is used to send responses related to preprocessing operations.
}
type ShardThread struct {
- id ShardID // id is the unique identifier for the shard.
- store *dstore.Store // store that the shard is responsible for.
- ReqChan chan *ops.StoreOp // ReqChan is this shard's channel for receiving requests.
- ioThreadMap map[string]IOChannels // ioThreadMap maps each io-thread id to its corresponding IOChannels, containing both the common and preprocessing response channels.
- mu sync.RWMutex // mu is the ioThreadMap's mutex for thread safety.
- globalErrorChan chan error // globalErrorChan is the channel for sending system-level errors.
- shardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors.
- lastCronExecTime time.Time // lastCronExecTime is the last time the shard executed cron tasks.
- cronFrequency time.Duration // cronFrequency is the frequency at which the shard executes cron tasks.
+ id ShardID // id is the unique identifier for the shard.
+ store *dstore.Store // store that the shard is responsible for.
+ ReqChan chan *ops.StoreOp // ReqChan is this shard's channel for receiving requests.
+ // cmdHandlerMap maps each command handler id to its corresponding CommandHandlerChannels, containing both the common and preprocessing response channels.
+ cmdHandlerMap map[string]CmdHandlerChannels
+ mu sync.RWMutex // mu is the cmdHandlerMap's mutex for thread safety.
+ globalErrorChan chan error // globalErrorChan is the channel for sending system-level errors.
+ shardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors.
+ lastCronExecTime time.Time // lastCronExecTime is the last time the shard executed cron tasks.
+ cronFrequency time.Duration // cronFrequency is the frequency at which the shard executes cron tasks.
}
// NewShardThread creates a new ShardThread instance with the given shard id and error channel.
@@ -63,7 +64,7 @@ func NewShardThread(id ShardID, gec chan error, sec chan *ShardError,
id: id,
store: dstore.NewStore(cmdWatchChan, evictionStrategy),
ReqChan: make(chan *ops.StoreOp, 1000),
- ioThreadMap: make(map[string]IOChannels),
+ cmdHandlerMap: make(map[string]CmdHandlerChannels),
globalErrorChan: gec,
shardErrorChan: sec,
lastCronExecTime: utils.GetCurrentTime(),
@@ -95,30 +96,30 @@ func (shard *ShardThread) runCronTasks() {
shard.lastCronExecTime = utils.GetCurrentTime()
}
-func (shard *ShardThread) registerIOThread(id string, responseChan, preprocessingChan chan *ops.StoreResponse) {
+func (shard *ShardThread) registerCommandHandler(id string, responseChan, preprocessingChan chan *ops.StoreResponse) {
shard.mu.Lock()
- shard.ioThreadMap[id] = IOChannels{
- CommonResponseChan: responseChan,
+ shard.cmdHandlerMap[id] = CmdHandlerChannels{
+ ResponseChan: responseChan,
PreProcessingResponseChan: preprocessingChan,
}
shard.mu.Unlock()
}
-func (shard *ShardThread) unregisterIOThread(id string) {
+func (shard *ShardThread) unregisterCommandHandler(id string) {
shard.mu.Lock()
- delete(shard.ioThreadMap, id)
+ delete(shard.cmdHandlerMap, id)
shard.mu.Unlock()
}
// processRequest processes a Store operation for the shard.
func (shard *ShardThread) processRequest(op *ops.StoreOp) {
shard.mu.RLock()
- ioChannels, ok := shard.ioThreadMap[op.IOThreadID]
+ channels, ok := shard.cmdHandlerMap[op.CmdHandlerID]
shard.mu.RUnlock()
- ioThreadChan := ioChannels.CommonResponseChan
- preProcessChan := ioChannels.PreProcessingResponseChan
+ cmdHandlerChan := channels.ResponseChan
+ preProcessChan := channels.PreProcessingResponseChan
sp := &ops.StoreResponse{
RequestID: op.RequestID,
@@ -140,11 +141,11 @@ func (shard *ShardThread) processRequest(op *ops.StoreOp) {
} else {
shard.shardErrorChan <- &ShardError{
ShardID: shard.id,
- Error: fmt.Errorf(diceerrors.IOThreadNotFoundErr, op.IOThreadID),
+ Error: fmt.Errorf(diceerrors.CmdHandlerNotFoundErr, op.CmdHandlerID),
}
}
- ioThreadChan <- sp
+ cmdHandlerChan <- sp
}
// cleanup handles cleanup logic when the shard stops.
diff --git a/main.go b/main.go
index e24f76a99..01ea98558 100644
--- a/main.go
+++ b/main.go
@@ -34,6 +34,7 @@ import (
"github.com/dicedb/dice/internal/server/httpws"
"github.com/dicedb/dice/internal/cli"
+ "github.com/dicedb/dice/internal/commandhandler"
"github.com/dicedb/dice/internal/logger"
"github.com/dicedb/dice/internal/server/abstractserver"
"github.com/dicedb/dice/internal/wal"
@@ -145,8 +146,10 @@ func main() {
}
defer stopProfiling()
}
- ioThreadManager := iothread.NewManager(config.DiceConfig.Performance.MaxClients, shardManager)
- respServer := resp.NewServer(shardManager, ioThreadManager, cmdWatchSubscriptionChan, cmdWatchChan, serverErrCh, wl)
+ ioThreadManager := iothread.NewManager(config.DiceConfig.Performance.MaxClients)
+ cmdHandlerManager := commandhandler.NewRegistry(config.DiceConfig.Performance.MaxClients, shardManager)
+
+ respServer := resp.NewServer(shardManager, ioThreadManager, cmdHandlerManager, cmdWatchSubscriptionChan, cmdWatchChan, serverErrCh, wl)
serverWg.Add(1)
go runServer(ctx, &serverWg, respServer, serverErrCh)