From 0ffb45d6850862ef161a326f12325966da91879e Mon Sep 17 00:00:00 2001 From: Prateek Singh Rathore Date: Wed, 4 Dec 2024 20:52:39 +0530 Subject: [PATCH] consolidated write to io channel in one place --- internal/commandhandler/commandhandler.go | 158 ++++++++++------------ 1 file changed, 70 insertions(+), 88 deletions(-) diff --git a/internal/commandhandler/commandhandler.go b/internal/commandhandler/commandhandler.go index 6313cac2e..623d9c848 100644 --- a/internal/commandhandler/commandhandler.go +++ b/internal/commandhandler/commandhandler.go @@ -88,85 +88,76 @@ func (h *BaseCommandHandler) Start(ctx context.Context) error { case err := <-errChan: return h.handleError(err) case data := <-h.ioThreadReadChan: - if err := h.processCommand(ctx, &data, h.globalErrorChan); err != nil { + resp, err := h.processCommand(ctx, &data, h.globalErrorChan) + if err != nil { + h.sendResponseToIOThread(err) return err } + h.sendResponseToIOThread(resp) } } } // processCommand processes commands recevied from io thread -func (h *BaseCommandHandler) processCommand(ctx context.Context, data *[]byte, errChan chan error) error { +func (h *BaseCommandHandler) processCommand(ctx context.Context, data *[]byte, gec chan error) (interface{}, error) { commands, err := h.parser.Parse(*data) if err != nil { slog.Debug("error parsing commands from io thread", slog.String("id", h.id), slog.Any("error", err)) - h.ioThreadWriteChan <- err - return nil + return nil, err } if len(commands) == 0 { slog.Debug("invalid request from io thread with zero length", slog.String("id", h.id)) - h.ioThreadWriteChan <- fmt.Errorf("ERR: Invalid request") - return nil + return nil, fmt.Errorf("ERR: Invalid request") } // DiceDB supports clients to send only one request at a time // We also need to ensure that the client is blocked until the response is received if len(commands) > 1 { - h.ioThreadWriteChan <- fmt.Errorf("ERR: Multiple commands not supported") - return nil + return nil, fmt.Errorf("ERR: Multiple commands not supported") } err = h.isAuthenticated(commands[0]) if err != nil { slog.Debug("command handler authentication failed", slog.String("id", h.id), slog.Any("error", err)) - h.ioThreadWriteChan <- err - return nil + return nil, err } - h.handleCmdRequestWithTimeout(ctx, errChan, commands, false, defaultRequestTimeout) - return nil + return h.handleCmdRequestWithTimeout(ctx, gec, commands, false, defaultRequestTimeout) } -func (h *BaseCommandHandler) handleCmdRequestWithTimeout(ctx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) { +func (h *BaseCommandHandler) handleCmdRequestWithTimeout(ctx context.Context, gec chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool, timeout time.Duration) (interface{}, error) { execCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - h.executeCommandHandler(execCtx, errChan, commands, isWatchNotification) -} - -func (h *BaseCommandHandler) handleError(err error) error { - if err != nil { - if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { - slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err)) - return err - } - } - return fmt.Errorf("error writing response: %v", err) + return h.executeCommandHandler(execCtx, gec, commands, isWatchNotification) } -func (h *BaseCommandHandler) executeCommandHandler(execCtx context.Context, errChan chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) { +func (h *BaseCommandHandler) executeCommandHandler(execCtx context.Context, gec chan error, commands []*cmd.DiceDBCmd, isWatchNotification bool) (interface{}, error) { // Retrieve metadata for the command to determine if multisharding is supported. meta, ok := CommandsMeta[commands[0].Cmd] if ok && meta.preProcessing { if err := meta.preProcessResponse(h, commands[0]); err != nil { slog.Debug("error pre processing response", slog.String("id", h.id), slog.Any("error", err)) - h.ioThreadWriteChan <- err + return nil, err } } - err := h.executeCommand(execCtx, commands[0], isWatchNotification) + resp, err := h.executeCommand(execCtx, commands[0], isWatchNotification) + + // log error and send to global error channel if it's a connection error if err != nil { slog.Error("Error executing command", slog.String("id", h.id), slog.Any("error", err)) if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err)) - errChan <- err + gec <- err } - h.ioThreadWriteChan <- err } + + return resp, err } -func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) error { +func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) (interface{}, error) { // Break down the single command into multiple commands if multisharding is supported. // The length of cmdList helps determine how many shards to wait for responses. cmdList := make([]*cmd.DiceDBCmd, 0) @@ -181,9 +172,8 @@ func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd. // Depending on the command type, decide how to handle it. switch meta.CmdType { case Global: - // If it's a global command, process it immediately without involving any shards. - h.ioThreadWriteChan <- meta.CmdHandlerFunction(diceDBCmd.Args) - return nil + // process global command immediately without involving any shards. + return meta.CmdHandlerFunction(diceDBCmd.Args), nil case SingleShard: // For single-shard or custom commands, process them without breaking up. @@ -198,11 +188,10 @@ func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd. // Check if it's a CustomError var customErr *diceerrors.PreProcessError if errors.As(err, &customErr) { - h.ioThreadWriteChan <- customErr.Result + return nil, fmt.Errorf("%v", customErr.Result) } else { - h.ioThreadWriteChan <- err + return nil, err } - return nil } case Custom: @@ -253,47 +242,42 @@ func (h *BaseCommandHandler) executeCommand(ctx context.Context, diceDBCmd *cmd. // Scatter the broken-down commands to the appropriate shards. if err := h.scatter(ctx, cmdList, meta.CmdType); err != nil { - return err + return nil, err } // Gather the responses from the shards and write them to the buffer. - if err := h.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel); err != nil { - return err + resp, err := h.gather(ctx, diceDBCmd, len(cmdList), isWatchNotification, watchLabel) + if err != nil { + return nil, err } + // Proceed to subscribe after successful execution if meta.CmdType == Watch { - // Proceed to subscribe after successful execution h.handleCommandWatch(cmdList) } - return nil + return resp, nil } -func (h *BaseCommandHandler) handleCustomCommands(diceDBCmd *cmd.DiceDBCmd) error { +func (h *BaseCommandHandler) handleCustomCommands(diceDBCmd *cmd.DiceDBCmd) (interface{}, error) { // if command is of type Custom, write a custom logic around it switch diceDBCmd.Cmd { case CmdAuth: - h.ioThreadWriteChan <- h.RespAuth(diceDBCmd.Args) - return nil + return h.RespAuth(diceDBCmd.Args), nil case CmdEcho: - h.ioThreadWriteChan <- RespEcho(diceDBCmd.Args) - return nil + return RespEcho(diceDBCmd.Args), nil case CmdAbort: - h.ioThreadWriteChan <- clientio.OK slog.Info("Received ABORT command, initiating server shutdown", slog.String("id", h.id)) h.globalErrorChan <- diceerrors.ErrAborted - return nil + return clientio.OK, nil case CmdPing: - h.ioThreadWriteChan <- RespPING(diceDBCmd.Args) - return nil + return RespPING(diceDBCmd.Args), nil case CmdHello: - h.ioThreadWriteChan <- RespHello(diceDBCmd.Args) - return nil + return RespHello(diceDBCmd.Args), nil case CmdSleep: - h.ioThreadWriteChan <- RespSleep(diceDBCmd.Args) - return nil + return RespSleep(diceDBCmd.Args), nil default: - return diceerrors.ErrUnknownCmd(diceDBCmd.Cmd) + return nil, diceerrors.ErrUnknownCmd(diceDBCmd.Cmd) } } @@ -308,13 +292,12 @@ func (h *BaseCommandHandler) handleCommandWatch(cmdList []*cmd.DiceDBCmd) { // handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client. // The response is sent before the unwatch request is processed by the watch manager. -func (h *BaseCommandHandler) handleCommandUnwatch(cmdList []*cmd.DiceDBCmd) error { +func (h *BaseCommandHandler) handleCommandUnwatch(cmdList []*cmd.DiceDBCmd) (interface{}, error) { // extract the fingerprint command := cmdList[len(cmdList)-1] fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32) if parseErr != nil { - h.ioThreadWriteChan <- diceerrors.ErrInvalidFingerprint - return nil + return nil, diceerrors.ErrInvalidFingerprint } // send the unsubscribe request @@ -324,8 +307,7 @@ func (h *BaseCommandHandler) handleCommandUnwatch(cmdList []*cmd.DiceDBCmd) erro Fingerprint: uint32(fp), } - h.ioThreadWriteChan <- clientio.OK - return nil + return clientio.OK, nil } // scatter distributes the DiceDB commands to the respective shards based on the key. @@ -389,28 +371,28 @@ func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string { // gather collects the responses from multiple shards and writes the results into the provided buffer. // It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). -func (h *BaseCommandHandler) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) error { +func (h *BaseCommandHandler) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool, watchLabel string) (interface{}, error) { // Collect responses from all shards storeOp, err := h.gatherResponses(ctx, numCmds) if err != nil { - return err + return nil, err } if len(storeOp) == 0 { slog.Error("No response from shards", slog.String("id", h.id), slog.String("command", diceDBCmd.Cmd)) - return fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd) + return nil, fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd) } if isWatchNotification { - return h.handleWatchNotification(ctx, diceDBCmd, storeOp[0], watchLabel) + return h.handleWatchNotification(diceDBCmd, storeOp[0], watchLabel) } // Process command based on its type cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd] if !ok { - return h.handleUnsupportedCommand(ctx, storeOp[0]) + return h.handleUnsupportedCommand(storeOp[0]) } return h.handleCommand(cmdMeta, diceDBCmd, storeOp) @@ -448,7 +430,7 @@ func (h *BaseCommandHandler) gatherResponses(ctx context.Context, numCmds int) ( } // handleWatchNotification processes watch notification responses -func (h *BaseCommandHandler) handleWatchNotification(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) error { +func (h *BaseCommandHandler) handleWatchNotification(diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse, watchLabel string) (interface{}, error) { fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint()) // if watch label is not empty, then this is the first response for the watch command @@ -459,55 +441,55 @@ func (h *BaseCommandHandler) handleWatchNotification(ctx context.Context, diceDB } if resp.EvalResponse.Error != nil { - return h.writeResponse(querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error)) + // This is a special case where error is returned as part of the watch response + return querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Error), nil } - return h.writeResponse(querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result)) + return querymanager.GenericWatchResponse(firstRespElem, fingerprint, resp.EvalResponse.Result), nil } // handleUnsupportedCommand processes commands not in CommandsMeta -func (h *BaseCommandHandler) handleUnsupportedCommand(ctx context.Context, resp ops.StoreResponse) error { +func (h *BaseCommandHandler) handleUnsupportedCommand(resp ops.StoreResponse) (interface{}, error) { if resp.EvalResponse.Error != nil { - return h.writeResponse(resp.EvalResponse.Error) + return nil, resp.EvalResponse.Error } - return h.writeResponse(resp.EvalResponse.Result) + return resp.EvalResponse.Result, nil } // handleCommand processes commands based on their type -func (h *BaseCommandHandler) handleCommand(cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) error { - var err error - +func (h *BaseCommandHandler) handleCommand(cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) (interface{}, error) { switch cmdMeta.CmdType { case SingleShard, Custom: if storeOp[0].EvalResponse.Error != nil { - err = h.writeResponse(storeOp[0].EvalResponse.Error) + return nil, storeOp[0].EvalResponse.Error } else { - err = h.writeResponse(storeOp[0].EvalResponse.Result) + return storeOp[0].EvalResponse.Result, nil } - if err == nil && h.wl != nil { - h.wl.LogCommand(diceDBCmd) - } case MultiShard, AllShard: - err = h.writeResponse(cmdMeta.composeResponse(storeOp...)) + return cmdMeta.composeResponse(storeOp...), nil - if err == nil && h.wl != nil { - h.wl.LogCommand(diceDBCmd) - } default: slog.Error("Unknown command type", slog.String("id", h.id), slog.String("command", diceDBCmd.Cmd), slog.Any("evalResp", storeOp)) - err = h.writeResponse(diceerrors.ErrInternalServer) + return nil, diceerrors.ErrInternalServer } - return err } -// writeResponse handles writing responses and logging errors -func (h *BaseCommandHandler) writeResponse(response interface{}) error { +func (h *BaseCommandHandler) handleError(err error) error { + if err != nil { + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { + slog.Debug("Connection closed for io-thread", slog.String("id", h.id), slog.Any("error", err)) + return err + } + } + return fmt.Errorf("error writing response: %v", err) +} + +func (h *BaseCommandHandler) sendResponseToIOThread(response interface{}) { h.ioThreadWriteChan <- response - return nil } func (h *BaseCommandHandler) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error {