From 92079f7850869e811c1f0bb629eee203328dcc2b Mon Sep 17 00:00:00 2001 From: Joe Corall <jjc223@lehigh.edu> Date: Fri, 4 Oct 2024 14:45:01 -0400 Subject: [PATCH] Separate server and stomp subscriber logic (#44) --- main.go | 298 +-------------------------------------------------- main_test.go | 27 ++--- server.go | 108 +++++++++++++++++++ stomp.go | 206 +++++++++++++++++++++++++++++++++++ 4 files changed, 334 insertions(+), 305 deletions(-) create mode 100644 server.go create mode 100644 stomp.go diff --git a/main.go b/main.go index 92cfc15..3cef14d 100644 --- a/main.go +++ b/main.go @@ -1,311 +1,23 @@ package main import ( - "bytes" - "encoding/base64" - "fmt" "log/slog" - "math/rand" - "net" - "net/http" "os" - "os/signal" - "sync" - "syscall" - "time" - stomp "github.com/go-stomp/stomp/v3" scyllaridae "github.com/lehigh-university-libraries/scyllaridae/internal/config" - "github.com/lehigh-university-libraries/scyllaridae/pkg/api" ) -var ( - config *scyllaridae.ServerConfig -) - -func init() { - var err error - - config, err = scyllaridae.ReadConfig("scyllaridae.yml") +func main() { + config, err := scyllaridae.ReadConfig("scyllaridae.yml") if err != nil { slog.Error("Could not read YML", "err", err) os.Exit(1) } -} -func main() { if len(config.QueueMiddlewares) > 0 { - stopChan := make(chan os.Signal, 1) - signal.Notify(stopChan, os.Interrupt, syscall.SIGTERM) - - var wg sync.WaitGroup - - for _, middleware := range config.QueueMiddlewares { - wg.Add(1) - go func(middleware scyllaridae.QueueMiddleware) { - defer wg.Done() - messageChan := make(chan *stomp.Message, middleware.Consumers) - - // Start the specified number of worker goroutines - for i := 0; i < middleware.Consumers; i++ { - slog.Info("Adding consumer", "consumer", i) - go worker(messageChan, middleware) - } - - RecvStompMessages(middleware.QueueName, messageChan) - }(middleware) - } - - <-stopChan - slog.Info("Shutting down message listener") - } else { - // or make this an available API ala crayfish - http.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - fmt.Fprintln(w, "OK") - }) - http.HandleFunc("/", MessageHandler) - port := os.Getenv("PORT") - if port == "" { - port = "8080" - } - - slog.Info("Server listening", "port", port) - if err := http.ListenAndServe(":"+port, nil); err != nil { - panic(err) - } - } -} - -func MessageHandler(w http.ResponseWriter, r *http.Request) { - slog.Info(r.RequestURI, "method", r.Method, "ip", r.RemoteAddr, "proto", r.Proto) - - if r.Method != http.MethodGet { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - defer r.Body.Close() - - if r.Header.Get("Apix-Ldp-Resource") == "" && r.Header.Get("X-Islandora-Event") == "" { - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - - // Read the Alpaca message payload - auth := "" - if config.ForwardAuth { - auth = r.Header.Get("Authorization") - } - message, err := api.DecodeAlpacaMessage(r, auth) - if err != nil { - slog.Error("Error decoding alpaca message", "err", err) - http.Error(w, "Internal error", http.StatusInternalServerError) - return - } - - // Stream the file contents from the source URL - req, err := http.NewRequest("GET", message.Attachment.Content.SourceURI, nil) - if err != nil { - slog.Error("Error creating request to source", "source", message.Attachment.Content.SourceURI, "err", err) - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - if config.ForwardAuth { - req.Header.Set("Authorization", auth) - } - sourceResp, err := http.DefaultClient.Do(req) - if err != nil { - slog.Error("Error fetching source file contents", "err", err) - http.Error(w, "Internal error", http.StatusInternalServerError) - return - } - defer sourceResp.Body.Close() - if sourceResp.StatusCode != http.StatusOK { - slog.Error("SourceURI sent a bad status code", "code", sourceResp.StatusCode, "uri", message.Attachment.Content.SourceURI) - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - - cmd, err := scyllaridae.BuildExecCommand(message, config) - if err != nil { - slog.Error("Error building command", "err", err) - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - cmd.Stdin = sourceResp.Body - - // Create a buffer to stream the output of the command - var stdErr bytes.Buffer - cmd.Stderr = &stdErr - - // send stdout to the ResponseWriter stream - cmd.Stdout = w - - slog.Info("Running command", "cmd", cmd.String()) - if err := cmd.Run(); err != nil { - slog.Error("Error running command", "cmd", cmd.String(), "err", stdErr.String()) - http.Error(w, "Internal error", http.StatusInternalServerError) - return - } -} - -func worker(messageChan <-chan *stomp.Message, middleware scyllaridae.QueueMiddleware) { - for msg := range messageChan { - handleMessage(msg, middleware) - } -} - -func RecvStompMessages(queueName string, messageChan chan<- *stomp.Message) { - attempt := 0 - maxAttempts := 30 - for attempt = 0; attempt < maxAttempts; attempt += 1 { - if err := connectAndSubscribe(queueName, messageChan); err != nil { - slog.Error("resubscribing", "queue", queueName, "error", err) - if err := retryWithExponentialBackoff(attempt, maxAttempts); err != nil { - slog.Error("Failed subscribing after too many failed attempts", "queue", queueName, "attempts", attempt) - return - } - } else { - // Subscription was successful - break - } - } -} - -func connectAndSubscribe(queueName string, messageChan chan<- *stomp.Message) error { - addr := os.Getenv("STOMP_SERVER_ADDR") - if addr == "" { - addr = "activemq:61613" - } - - c, err := net.Dial("tcp", addr) - if err != nil { - slog.Error("cannot connect to port", "err", err.Error()) - return err - } - tcpConn := c.(*net.TCPConn) - - err = tcpConn.SetKeepAlive(true) - if err != nil { - slog.Error("cannot set keepalive", "err", err.Error()) - return err - } - - err = tcpConn.SetKeepAlivePeriod(10 * time.Second) - if err != nil { - slog.Error("cannot set keepalive period", "err", err.Error()) - return err - } - - conn, err := stomp.Connect(tcpConn, stomp.ConnOpt.HeartBeat(10*time.Second, 0*time.Second)) - if err != nil { - slog.Error("cannot connect to stomp server", "err", err.Error()) - return err - } - defer func() { - err := conn.Disconnect() - if err != nil { - slog.Error("problem disconnecting from stomp server", "err", err) - } - }() - - sub, err := conn.Subscribe(queueName, stomp.AckAuto) - if err != nil { - slog.Error("cannot subscribe to queue", "queue", queueName, "err", err.Error()) - return err - } - defer func() { - if !sub.Active() { - return - } - err := sub.Unsubscribe() - if err != nil { - slog.Error("problem unsubscribing", "err", err) - } - }() - slog.Info("Server subscribed to", "queue", queueName) - - for msg := range sub.C { - if msg == nil || len(msg.Body) == 0 { - if !sub.Active() { - return fmt.Errorf("no longer subscribed to %s", queueName) - } - continue - } - messageChan <- msg // Send the message to the channel - } - - return nil -} - -func handleMessage(msg *stomp.Message, middleware scyllaridae.QueueMiddleware) { - req, err := http.NewRequest("GET", middleware.Url, nil) - if err != nil { - slog.Error("Error creating HTTP request", "url", middleware.Url, "err", err) - return - } - - req.Header.Set("X-Islandora-Event", base64.StdEncoding.EncodeToString(msg.Body)) - islandoraMessage, err := api.DecodeEventMessage(msg.Body) - if err != nil { - slog.Error("Unable to decode event message", "err", err) - return - } - - if middleware.ForwardAuth { - auth := msg.Header.Get("Authorization") - if auth != "" { - req.Header.Set("Authorization", auth) - } - } - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - slog.Error("Error sending HTTP GET request", "url", middleware.Url, "err", err) - return - } - defer resp.Body.Close() - - if resp.StatusCode >= 299 { - slog.Error("Failed to deliver message", "url", middleware.Url, "status", resp.StatusCode) - return - } - - if middleware.NoPut { - return - } - - putReq, err := http.NewRequest("PUT", islandoraMessage.Attachment.Content.DestinationURI, resp.Body) - if err != nil { - slog.Error("Error creating HTTP PUT request", "url", islandoraMessage.Attachment.Content.DestinationURI, "err", err) - return - } - - putReq.Header.Set("Authorization", msg.Header.Get("Authorization")) - putReq.Header.Set("Content-Type", islandoraMessage.Attachment.Content.DestinationMimeType) - putReq.Header.Set("Content-Location", islandoraMessage.Attachment.Content.FileUploadURI) - - // Send the PUT request - putResp, err := client.Do(putReq) - if err != nil { - slog.Error("Error sending HTTP PUT request", "url", islandoraMessage.Attachment.Content.DestinationURI, "err", err) - return - } - defer putResp.Body.Close() - - if putResp.StatusCode >= 299 { - slog.Error("Failed to PUT data", "url", islandoraMessage.Attachment.Content.DestinationURI, "status", putResp.StatusCode) + runStompSubscribers(config) } else { - slog.Info("Successfully PUT data to", "url", islandoraMessage.Attachment.Content.DestinationURI, "status", putResp.StatusCode) - } -} - -func retryWithExponentialBackoff(attempt int, maxAttempts int) error { - if attempt >= maxAttempts { - return fmt.Errorf("maximum retry attempts reached") + server := &Server{Config: config} + runHTTPServer(server) } - wait := time.Duration(rand.Intn(1<<attempt)) * time.Second - time.Sleep(wait) - return nil } diff --git a/main_test.go b/main_test.go index a4ad585..327a682 100644 --- a/main_test.go +++ b/main_test.go @@ -25,18 +25,19 @@ type Test struct { } func TestMessageHandler_MethodNotAllowed(t *testing.T) { + testConfig := &scyllaridae.ServerConfig{} + server := &Server{Config: testConfig} + req, err := http.NewRequest("POST", "/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() - handler := http.HandlerFunc(MessageHandler) - + handler := http.HandlerFunc(server.MessageHandler) handler.ServeHTTP(rr, req) - if status := rr.Code; status != http.StatusMethodNotAllowed { - t.Errorf("handler returned wrong status code: got %v want %v", + t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusMethodNotAllowed) } } @@ -216,19 +217,20 @@ cmdByMimeType: destinationServer := createMockDestinationServer(t, tt.returnedBody) defer destinationServer.Close() - sourceServer := createMockSourceServer(t, tt.mimetype, tt.authHeader, destinationServer.URL) - defer sourceServer.Close() - os.Setenv("SCYLLARIDAE_YML", tt.yml) - // set the config based on tt.yml - config, err = scyllaridae.ReadConfig("") + config, err := scyllaridae.ReadConfig("") + + sourceServer := createMockSourceServer(t, config, tt.mimetype, tt.authHeader, destinationServer.URL) + defer sourceServer.Close() if err != nil { t.Fatalf("Could not read YML: %v", err) - os.Exit(1) } + // Create a Server instance with the test config + server := &Server{Config: config} + // Configure and start the main server - setupServer := httptest.NewServer(http.HandlerFunc(MessageHandler)) + setupServer := httptest.NewServer(http.HandlerFunc(server.MessageHandler)) defer setupServer.Close() // Send the mock message to the main server @@ -260,6 +262,7 @@ cmdByMimeType: } }) } + } func createMockDestinationServer(t *testing.T, content string) *httptest.Server { @@ -270,7 +273,7 @@ func createMockDestinationServer(t *testing.T, content string) *httptest.Server })) } -func createMockSourceServer(t *testing.T, mimetype, auth, content string) *httptest.Server { +func createMockSourceServer(t *testing.T, config *scyllaridae.ServerConfig, mimetype, auth, content string) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if config.ForwardAuth && r.Header.Get("Authorization") != auth { w.WriteHeader(http.StatusUnauthorized) diff --git a/server.go b/server.go new file mode 100644 index 0000000..9bd25a2 --- /dev/null +++ b/server.go @@ -0,0 +1,108 @@ +package main + +import ( + "bytes" + "fmt" + "log/slog" + "net/http" + "os" + + scyllaridae "github.com/lehigh-university-libraries/scyllaridae/internal/config" + "github.com/lehigh-university-libraries/scyllaridae/pkg/api" +) + +type Server struct { + Config *scyllaridae.ServerConfig +} + +func runHTTPServer(server *Server) { + http.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, "OK") + }) + + // Use the method as the handler + http.HandleFunc("/", server.MessageHandler) + + port := os.Getenv("PORT") + if port == "" { + port = "8080" + } + + slog.Info("Server listening", "port", port) + if err := http.ListenAndServe(":"+port, nil); err != nil { + panic(err) + } +} + +func (s *Server) MessageHandler(w http.ResponseWriter, r *http.Request) { + slog.Info(r.RequestURI, "method", r.Method, "ip", r.RemoteAddr, "proto", r.Proto) + + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + defer r.Body.Close() + + if r.Header.Get("Apix-Ldp-Resource") == "" && r.Header.Get("X-Islandora-Event") == "" { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Read the Alpaca message payload + auth := "" + if s.Config.ForwardAuth { + auth = r.Header.Get("Authorization") + } + message, err := api.DecodeAlpacaMessage(r, auth) + if err != nil { + slog.Error("Error decoding alpaca message", "err", err) + http.Error(w, "Internal error", http.StatusInternalServerError) + return + } + + // Stream the file contents from the source URL + req, err := http.NewRequest("GET", message.Attachment.Content.SourceURI, nil) + if err != nil { + slog.Error("Error creating request to source", "source", message.Attachment.Content.SourceURI, "err", err) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + if s.Config.ForwardAuth { + req.Header.Set("Authorization", auth) + } + sourceResp, err := http.DefaultClient.Do(req) + if err != nil { + slog.Error("Error fetching source file contents", "err", err) + http.Error(w, "Internal error", http.StatusInternalServerError) + return + } + defer sourceResp.Body.Close() + if sourceResp.StatusCode != http.StatusOK { + slog.Error("SourceURI sent a bad status code", "code", sourceResp.StatusCode, "uri", message.Attachment.Content.SourceURI) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + cmd, err := scyllaridae.BuildExecCommand(message, s.Config) + if err != nil { + slog.Error("Error building command", "err", err) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + cmd.Stdin = sourceResp.Body + + // Create a buffer to capture stderr + var stdErr bytes.Buffer + cmd.Stderr = &stdErr + + // Send stdout to the ResponseWriter stream + cmd.Stdout = w + + slog.Info("Running command", "cmd", cmd.String()) + if err := cmd.Run(); err != nil { + slog.Error("Error running command", "cmd", cmd.String(), "err", stdErr.String()) + http.Error(w, "Internal error", http.StatusInternalServerError) + return + } +} diff --git a/stomp.go b/stomp.go new file mode 100644 index 0000000..6fb3a1d --- /dev/null +++ b/stomp.go @@ -0,0 +1,206 @@ +package main + +import ( + "encoding/base64" + "fmt" + "log/slog" + "math/rand" + "net" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" + + stomp "github.com/go-stomp/stomp/v3" + scyllaridae "github.com/lehigh-university-libraries/scyllaridae/internal/config" + "github.com/lehigh-university-libraries/scyllaridae/pkg/api" +) + +func runStompSubscribers(config *scyllaridae.ServerConfig) { + stopChan := make(chan os.Signal, 1) + signal.Notify(stopChan, os.Interrupt, syscall.SIGTERM) + + var wg sync.WaitGroup + + for _, middleware := range config.QueueMiddlewares { + wg.Add(1) + go func(middleware scyllaridae.QueueMiddleware) { + defer wg.Done() + messageChan := make(chan *stomp.Message, middleware.Consumers) + + // Start the specified number of worker goroutines + for i := 0; i < middleware.Consumers; i++ { + slog.Info("Adding consumer", "consumer", i) + go worker(messageChan, middleware) + } + + RecvStompMessages(middleware.QueueName, messageChan) + }(middleware) + } + + <-stopChan + slog.Info("Shutting down message listener") +} + +func worker(messageChan <-chan *stomp.Message, middleware scyllaridae.QueueMiddleware) { + for msg := range messageChan { + handleMessage(msg, middleware) + } +} + +func RecvStompMessages(queueName string, messageChan chan<- *stomp.Message) { + attempt := 0 + maxAttempts := 30 + for attempt = 0; attempt < maxAttempts; attempt++ { + if err := connectAndSubscribe(queueName, messageChan); err != nil { + slog.Error("Resubscribing", "queue", queueName, "error", err) + if err := retryWithExponentialBackoff(attempt, maxAttempts); err != nil { + slog.Error("Failed subscribing after too many failed attempts", "queue", queueName, "attempts", attempt) + return + } + } else { + // Subscription was successful + break + } + } +} + +func connectAndSubscribe(queueName string, messageChan chan<- *stomp.Message) error { + addr := os.Getenv("STOMP_SERVER_ADDR") + if addr == "" { + addr = "activemq:61613" + } + + c, err := net.Dial("tcp", addr) + if err != nil { + slog.Error("Cannot connect to port", "err", err.Error()) + return err + } + tcpConn := c.(*net.TCPConn) + + err = tcpConn.SetKeepAlive(true) + if err != nil { + slog.Error("Cannot set keepalive", "err", err.Error()) + return err + } + + err = tcpConn.SetKeepAlivePeriod(10 * time.Second) + if err != nil { + slog.Error("Cannot set keepalive period", "err", err.Error()) + return err + } + + conn, err := stomp.Connect(tcpConn, stomp.ConnOpt.HeartBeat(10*time.Second, 0*time.Second)) + if err != nil { + slog.Error("Cannot connect to STOMP server", "err", err.Error()) + return err + } + defer func() { + err := conn.Disconnect() + if err != nil { + slog.Error("Problem disconnecting from STOMP server", "err", err) + } + }() + + sub, err := conn.Subscribe(queueName, stomp.AckAuto) + if err != nil { + slog.Error("Cannot subscribe to queue", "queue", queueName, "err", err.Error()) + return err + } + defer func() { + if !sub.Active() { + return + } + err := sub.Unsubscribe() + if err != nil { + slog.Error("Problem unsubscribing", "err", err) + } + }() + slog.Info("Server subscribed to", "queue", queueName) + + for msg := range sub.C { + if msg == nil || len(msg.Body) == 0 { + if !sub.Active() { + return fmt.Errorf("no longer subscribed to %s", queueName) + } + continue + } + messageChan <- msg // Send the message to the channel + } + + return nil +} + +func handleMessage(msg *stomp.Message, middleware scyllaridae.QueueMiddleware) { + req, err := http.NewRequest("GET", middleware.Url, nil) + if err != nil { + slog.Error("Error creating HTTP request", "url", middleware.Url, "err", err) + return + } + + req.Header.Set("X-Islandora-Event", base64.StdEncoding.EncodeToString(msg.Body)) + islandoraMessage, err := api.DecodeEventMessage(msg.Body) + if err != nil { + slog.Error("Unable to decode event message", "err", err) + return + } + + if middleware.ForwardAuth { + auth := msg.Header.Get("Authorization") + if auth != "" { + req.Header.Set("Authorization", auth) + } + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + slog.Error("Error sending HTTP GET request", "url", middleware.Url, "err", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode >= 299 { + slog.Error("Failed to deliver message", "url", middleware.Url, "status", resp.StatusCode) + return + } + + if middleware.NoPut { + return + } + + putReq, err := http.NewRequest("PUT", islandoraMessage.Attachment.Content.DestinationURI, resp.Body) + if err != nil { + slog.Error("Error creating HTTP PUT request", "url", islandoraMessage.Attachment.Content.DestinationURI, "err", err) + return + } + + putReq.Header.Set("Authorization", msg.Header.Get("Authorization")) + putReq.Header.Set("Content-Type", islandoraMessage.Attachment.Content.DestinationMimeType) + putReq.Header.Set("Content-Location", islandoraMessage.Attachment.Content.FileUploadURI) + + // Send the PUT request + putResp, err := client.Do(putReq) + if err != nil { + slog.Error("Error sending HTTP PUT request", "url", islandoraMessage.Attachment.Content.DestinationURI, "err", err) + return + } + defer putResp.Body.Close() + + if putResp.StatusCode >= 299 { + slog.Error("Failed to PUT data", "url", islandoraMessage.Attachment.Content.DestinationURI, "status", putResp.StatusCode) + } else { + slog.Info("Successfully PUT data to", "url", islandoraMessage.Attachment.Content.DestinationURI, "status", putResp.StatusCode) + } +} + +func retryWithExponentialBackoff(attempt int, maxAttempts int) error { + if attempt >= maxAttempts { + return fmt.Errorf("maximum retry attempts reached") + } + wait := time.Duration(rand.Intn(1<<attempt)) * time.Second + time.Sleep(wait) + return nil +}