diff --git a/Makefile b/Makefile index 8b80520..936e258 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,10 @@ lint-install: .PHONY: lint lint: echo "Running checks for service" - golangci-lint run ./... + golangci-lint run ./internal/... + golangci-lint run ./cmd/generic/... + golangci-lint run ./cmd/mysql/... + golangci-lint run ./cmd/postgres/... .PHONY: sec-install sec-install: diff --git a/cmd/generic/cmd.go b/cmd/generic/cmd.go index 0fb4629..05352b6 100644 --- a/cmd/generic/cmd.go +++ b/cmd/generic/cmd.go @@ -22,7 +22,7 @@ import ( * 1. Start frontend(end client to proxy) TCP listeners. * 2. Discover backend instance's endpoint via mapped proxy port. * 2.a If backend instance is paused, starting the backend instance and holding frontend connections until backend instance is active. - * 3. Start backend(proxy to postgres instance) TCP channel. + * 3. Start backend(proxy to serverless resource instance) TCP channel. * 4. Forward data from frontend to backend and forward response data from backend to frontend. */ func main() { @@ -34,7 +34,7 @@ func main() { for i := 0; i <= 9; i++ { listenAddr := "0.0.0.0:3000" + strconv.Itoa(i) // #nosec G102 - //Setup frontend TCP listener + // Setup frontend TCP listener listener, err := net.ListenTCP("tcp", getResolvedAddresses(listenAddr)) if err != nil { log.Printf("Failed to listen: %v", err) @@ -44,12 +44,6 @@ func main() { listeners = append(listeners, *listener) } - defer func() { - for _, listener := range listeners { - listener.Close() - } - }() - // Initialize Omnistrate sidecar sidecarClient var sidecarClient = sidecar.NewClient(context.Background()) @@ -68,13 +62,13 @@ func main() { } chExit := make(chan os.Signal, 1) - signal.Notify(chExit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL) - select { - case <-chExit: - log.Println("Example EXITING...Bye.") - os.Exit(1) + signal.Notify(chExit, syscall.SIGINT, syscall.SIGTERM) + <-chExit + log.Println("EXITING...Bye.") + for _, listener := range listeners { + listener.Close() } - + os.Exit(1) } func handleClient(frontEndConnection *net.TCPConn, sidecarClient *sidecar.Client) { @@ -88,171 +82,160 @@ func handleClient(frontEndConnection *net.TCPConn, sidecarClient *sidecar.Client return } - inputBuffer := make([]byte, 0xffff) - size, err := frontEndConnection.Read(inputBuffer) - if err != nil { - log.Printf("Failed to read from client: %v", err) - return - } - - inputBuffer, err = getModifiedBuffer(inputBuffer[:size]) - if err != nil { - log.Printf("%s\n", err) - return - } - - // Check if the input is a psql connection - // First 8 bytes will be - // 00 00 00 08 04 d2 16 2f - if inputBuffer[3] != 0x08 && - inputBuffer[4] != 0x04 && - inputBuffer[5] != 0xd2 && - inputBuffer[6] != 0x16 && - inputBuffer[7] != 0x2f { - log.Printf("Not a psql connection") - return - } - + var serverlessTargetPort string var hostName string + var backendConnection *net.TCPConn if os.Getenv("DRY_RUN") == "true" { - hostName = "localhost" - } else { - // Step 2: Discover backend instance's endpoint via mapped proxy port. var err error - var response *http.Response - if response, err = sidecarClient.QueryBackendInstanceStatus(port); err != nil || response.StatusCode != 200 { - log.Printf("Failed to get backends endpoints") + hostName = "127.0.0.1" + serverlessTargetPort = "3306" + hostName = hostName + ":" + serverlessTargetPort + backendConnection, err = net.DialTCP("tcp", nil, getResolvedAddresses(hostName)) + if err != nil { + log.Printf("Remote connection failed: %s", err) return } + } else { + retryCount := 0 + for retryCount < 22 { + // Step 2: Discover backend instance's endpoint via mapped proxy port. + var err error + var response *http.Response + if response, err = sidecarClient.QueryBackendInstanceStatus(port); err != nil || response.StatusCode != 200 { + log.Printf("Failed to get backends endpoints") + return + } - var body []byte - if body, err = io.ReadAll(response.Body); err != nil { - log.Printf("Failed to read response body") - return - } + var body []byte + if body, err = io.ReadAll(response.Body); err != nil { + log.Printf("Failed to read response body") + return + } - responseBody := &sidecar.InstanceStatus{} + responseBody := &sidecar.InstanceStatus{} - if err = json.Unmarshal(body, &responseBody); err != nil { - log.Printf("Failed to unmarshal response body") - } + if err = json.Unmarshal(body, &responseBody); err != nil { + log.Printf("Failed to unmarshal response body") + } - log.Printf("Instance response: %s", responseBody) - - switch responseBody.Status { - // Step 2a: if backend instance is paused, starting the backend instance and holding frontend connections until backend instance is active. - // In this example, we are using 22 retries with 15 seconds interval to check backend instance status. - case sidecar.PAUSED: - log.Printf("Instance is paused, waking up instance") - sidecarClient.StartInstance(responseBody.InstanceID) - retryCount := 0 - for retryCount < 22 { - if response, err = sidecarClient.QueryBackendInstanceStatus(port); err != nil || response.StatusCode != 200 { - log.Printf("Failed to get backends endpoints %d times", retryCount) + log.Printf("Instance response: %s", responseBody) + + switch responseBody.Status { + // Step 2a: if backend instance is paused, starting the backend instance and holding frontend connections until backend instance is active. + // In this example, we are using 22 retries with 15 seconds interval to check backend instance status. + case sidecar.PAUSED: + log.Printf("Instance is paused, waking up instance") + resp, err := sidecarClient.StartInstance(responseBody.InstanceID) + if err != nil { + log.Printf("Failed to start instance: %v", err) return } - - var body []byte - if body, err = io.ReadAll(response.Body); err != nil { - log.Printf("Failed to read response body") + defer resp.Body.Close() + case sidecar.ACTIVE: + fallthrough + case sidecar.STARTING: + fallthrough + // If status unknown, still try to connect to avoid system glitch. + case sidecar.UNKNOWN: + serverlessTargetPort = os.Getenv("TARGET_PORT") + if serverlessTargetPort == "" { + log.Printf("Failed to get serverless target port") return } - if err = json.Unmarshal(body, &responseBody); err != nil { - log.Printf("Failed to unmarshal response body") + serverlessResourceKey := os.Getenv("SERVERLESS_RESOURCE_KEY") + if serverlessResourceKey == "" { + log.Printf("Failed to get serverless resource key") return } - log.Printf("Instance status: %s", responseBody.Status) - - if responseBody.Status == sidecar.ACTIVE { - break + log.Printf("Instance is %s, trying to dial TCP.", responseBody.Status) + + for _, sc := range responseBody.ServiceComponents { + if strings.Contains(sc.Alias, serverlessResourceKey) { + hostName = serverlessResourceKey + "." + responseBody.InstanceID + hostName = hostName + ":" + serverlessTargetPort + // Step 3: connect to backend serverless resource server + backendConnection, err = net.DialTCP("tcp", nil, getResolvedAddresses(hostName)) + if err != nil { + log.Printf("Remote connection failed: %s", err) + } + break + } } - time.Sleep(15 * time.Second) - retryCount++ + default: + log.Printf("Instance is not in expected status %s, exiting...", responseBody.Status) + return } - case sidecar.STARTING: - log.Printf("Instance is starting, waiting for instance to be available") - if _, err = frontEndConnection.Write([]byte("Instance is starting, waiting for instance to be available\n")); err != nil { - log.Printf("Failed to write to client: %v", err) + + if responseBody.Status == sidecar.ACTIVE { + break } - return - } - if responseBody.Status == sidecar.ACTIVE { - for _, sc := range responseBody.ServiceComponents { - if strings.Contains(sc.Alias, "postgres") { - hostName = sc.NodesEndpoints[0].Endpoint - break - } + if responseBody.Status != sidecar.STARTING && responseBody.Status != sidecar.PAUSED { + break } - if hostName == "" { - log.Printf("Failed to get postgres endpoint") - return + + if responseBody.Status == sidecar.STARTING && backendConnection != nil { + break } - } else { - log.Printf("Instance is not active, exiting...") - return - } - defer func() { - if response != nil { - if closeErr := response.Body.Close(); closeErr != nil { - log.Printf("Failed to close response body: %v", closeErr) + time.Sleep(5 * time.Second) + retryCount++ + + defer func() { + if response != nil { + if closeErr := response.Body.Close(); closeErr != nil { + log.Printf("Failed to close response body: %v", closeErr) + } } - } - }() + }() + } } - // Backend port is depends on actual postgres port, in this example, we are using 5432 - hostName = hostName + ":5432" - // Step 3: connect to backend postgres server - backendConnection, err := net.DialTCP("tcp", nil, getResolvedAddresses(hostName)) - if err != nil { - log.Printf("Remote connection failed: %s", err) + if backendConnection == nil { + log.Printf("Didn't get backend connection established in time, exiting...") return } // Step 4: Forward data from frontend to backend and forward response data from backend to frontend. - go handleIncomingConnection(frontEndConnection, backendConnection, inputBuffer) + go handleIncomingConnection(frontEndConnection, backendConnection) go handleResponseConnection(backendConnection, frontEndConnection) - - // TODO: Close frontend/backend connections } /** * This function is used to forward data from frontend to backend. srcChannel is frontend connection, dstChannel is backend connection. */ -func handleIncomingConnection(srcChannel, dstChannel *net.TCPConn, firstPacket []byte) { +func handleIncomingConnection(srcChannel, dstChannel *net.TCPConn) { buff := make([]byte, 0xffff) - firstTime := true for { var b []byte - if !firstTime { - n, err := srcChannel.Read(buff) - if err != nil { - log.Printf("Read failed '%s'\n", err) - return - } - - // Note that you can add any custom logic, like authentication, authorization - // before sending data to the backend postgres server. - b, err = getModifiedBuffer(buff[:n]) - if err != nil { - log.Printf("%s\n", err) + n, err := srcChannel.Read(buff) + if err != nil { + if err == io.EOF { err = dstChannel.Close() if err != nil { - log.Printf("connection closed failed '%s'\n", err) + log.Printf("backend connection closed failed '%s'\n", err) } - return } - } else { - b = firstPacket - firstTime = false + log.Printf("Read failed '%s'\n", err) + return + } + + // Note that you can add any custom logic, like authentication, authorization + // before sending data to the backend serverless resource server. + b, err = getModifiedBuffer(buff[:n]) + if err != nil { + log.Printf("%s\n", err) + err = dstChannel.Close() + if err != nil { + log.Printf("connection closed failed '%s'\n", err) + } + return } - _, err := dstChannel.Write(b) + _, err = dstChannel.Write(b) if err != nil { log.Printf("Write failed '%s'\n", err) return @@ -275,7 +258,7 @@ func handleResponseConnection(srcChannel, dstChannel *net.TCPConn) { } b := setResponseBuffer(buff[:n]) - n, err = dstChannel.Write(b) + _, err = dstChannel.Write(b) if err != nil { log.Printf("Write failed '%s'\n", err) return