Skip to content

Commit

Permalink
fix linter
Browse files Browse the repository at this point in the history
  • Loading branch information
pberton committed Oct 15, 2024
1 parent 6c7e4b8 commit e783f9c
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 134 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ lint-install:
.PHONY: lint
lint:
echo "Running checks for service"
golangci-lint run ./...
golangci-lint run ./internal/...
golangci-lint run ./cmd/generic/...
golangci-lint run ./cmd/mysql/...
golangci-lint run ./cmd/postgres/...

.PHONY: sec-install
sec-install:
Expand Down
249 changes: 116 additions & 133 deletions cmd/generic/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
* 1. Start frontend(end client to proxy) TCP listeners.
* 2. Discover backend instance's endpoint via mapped proxy port.
* 2.a If backend instance is paused, starting the backend instance and holding frontend connections until backend instance is active.
* 3. Start backend(proxy to postgres instance) TCP channel.
* 3. Start backend(proxy to serverless resource instance) TCP channel.
* 4. Forward data from frontend to backend and forward response data from backend to frontend.
*/
func main() {
Expand All @@ -34,7 +34,7 @@ func main() {
for i := 0; i <= 9; i++ {
listenAddr := "0.0.0.0:3000" + strconv.Itoa(i) // #nosec G102

//Setup frontend TCP listener
// Setup frontend TCP listener
listener, err := net.ListenTCP("tcp", getResolvedAddresses(listenAddr))
if err != nil {
log.Printf("Failed to listen: %v", err)
Expand All @@ -44,12 +44,6 @@ func main() {
listeners = append(listeners, *listener)
}

defer func() {
for _, listener := range listeners {
listener.Close()
}
}()

// Initialize Omnistrate sidecar sidecarClient
var sidecarClient = sidecar.NewClient(context.Background())

Expand All @@ -68,13 +62,13 @@ func main() {
}

chExit := make(chan os.Signal, 1)
signal.Notify(chExit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL)
select {
case <-chExit:
log.Println("Example EXITING...Bye.")
os.Exit(1)
signal.Notify(chExit, syscall.SIGINT, syscall.SIGTERM)
<-chExit
log.Println("EXITING...Bye.")
for _, listener := range listeners {
listener.Close()
}

os.Exit(1)
}

func handleClient(frontEndConnection *net.TCPConn, sidecarClient *sidecar.Client) {
Expand All @@ -88,171 +82,160 @@ func handleClient(frontEndConnection *net.TCPConn, sidecarClient *sidecar.Client
return
}

inputBuffer := make([]byte, 0xffff)
size, err := frontEndConnection.Read(inputBuffer)
if err != nil {
log.Printf("Failed to read from client: %v", err)
return
}

inputBuffer, err = getModifiedBuffer(inputBuffer[:size])
if err != nil {
log.Printf("%s\n", err)
return
}

// Check if the input is a psql connection
// First 8 bytes will be
// 00 00 00 08 04 d2 16 2f
if inputBuffer[3] != 0x08 &&
inputBuffer[4] != 0x04 &&
inputBuffer[5] != 0xd2 &&
inputBuffer[6] != 0x16 &&
inputBuffer[7] != 0x2f {
log.Printf("Not a psql connection")
return
}

var serverlessTargetPort string
var hostName string
var backendConnection *net.TCPConn
if os.Getenv("DRY_RUN") == "true" {
hostName = "localhost"
} else {
// Step 2: Discover backend instance's endpoint via mapped proxy port.
var err error
var response *http.Response
if response, err = sidecarClient.QueryBackendInstanceStatus(port); err != nil || response.StatusCode != 200 {
log.Printf("Failed to get backends endpoints")
hostName = "127.0.0.1"
serverlessTargetPort = "3306"
hostName = hostName + ":" + serverlessTargetPort
backendConnection, err = net.DialTCP("tcp", nil, getResolvedAddresses(hostName))
if err != nil {
log.Printf("Remote connection failed: %s", err)
return
}
} else {
retryCount := 0
for retryCount < 22 {
// Step 2: Discover backend instance's endpoint via mapped proxy port.
var err error
var response *http.Response
if response, err = sidecarClient.QueryBackendInstanceStatus(port); err != nil || response.StatusCode != 200 {
log.Printf("Failed to get backends endpoints")
return
}

var body []byte
if body, err = io.ReadAll(response.Body); err != nil {
log.Printf("Failed to read response body")
return
}
var body []byte
if body, err = io.ReadAll(response.Body); err != nil {
log.Printf("Failed to read response body")
return
}

responseBody := &sidecar.InstanceStatus{}
responseBody := &sidecar.InstanceStatus{}

if err = json.Unmarshal(body, &responseBody); err != nil {
log.Printf("Failed to unmarshal response body")
}
if err = json.Unmarshal(body, &responseBody); err != nil {
log.Printf("Failed to unmarshal response body")
}

log.Printf("Instance response: %s", responseBody)

switch responseBody.Status {
// Step 2a: if backend instance is paused, starting the backend instance and holding frontend connections until backend instance is active.
// In this example, we are using 22 retries with 15 seconds interval to check backend instance status.
case sidecar.PAUSED:
log.Printf("Instance is paused, waking up instance")
sidecarClient.StartInstance(responseBody.InstanceID)
retryCount := 0
for retryCount < 22 {
if response, err = sidecarClient.QueryBackendInstanceStatus(port); err != nil || response.StatusCode != 200 {
log.Printf("Failed to get backends endpoints %d times", retryCount)
log.Printf("Instance response: %s", responseBody)

switch responseBody.Status {
// Step 2a: if backend instance is paused, starting the backend instance and holding frontend connections until backend instance is active.
// In this example, we are using 22 retries with 15 seconds interval to check backend instance status.
case sidecar.PAUSED:
log.Printf("Instance is paused, waking up instance")
resp, err := sidecarClient.StartInstance(responseBody.InstanceID)
if err != nil {
log.Printf("Failed to start instance: %v", err)
return
}

var body []byte
if body, err = io.ReadAll(response.Body); err != nil {
log.Printf("Failed to read response body")
defer resp.Body.Close()
case sidecar.ACTIVE:
fallthrough
case sidecar.STARTING:
fallthrough
// If status unknown, still try to connect to avoid system glitch.
case sidecar.UNKNOWN:
serverlessTargetPort = os.Getenv("TARGET_PORT")
if serverlessTargetPort == "" {
log.Printf("Failed to get serverless target port")
return
}

if err = json.Unmarshal(body, &responseBody); err != nil {
log.Printf("Failed to unmarshal response body")
serverlessResourceKey := os.Getenv("SERVERLESS_RESOURCE_KEY")
if serverlessResourceKey == "" {
log.Printf("Failed to get serverless resource key")
return
}

log.Printf("Instance status: %s", responseBody.Status)

if responseBody.Status == sidecar.ACTIVE {
break
log.Printf("Instance is %s, trying to dial TCP.", responseBody.Status)

for _, sc := range responseBody.ServiceComponents {
if strings.Contains(sc.Alias, serverlessResourceKey) {
hostName = serverlessResourceKey + "." + responseBody.InstanceID
hostName = hostName + ":" + serverlessTargetPort
// Step 3: connect to backend serverless resource server
backendConnection, err = net.DialTCP("tcp", nil, getResolvedAddresses(hostName))
if err != nil {
log.Printf("Remote connection failed: %s", err)
}
break
}
}
time.Sleep(15 * time.Second)
retryCount++
default:
log.Printf("Instance is not in expected status %s, exiting...", responseBody.Status)
return
}
case sidecar.STARTING:
log.Printf("Instance is starting, waiting for instance to be available")
if _, err = frontEndConnection.Write([]byte("Instance is starting, waiting for instance to be available\n")); err != nil {
log.Printf("Failed to write to client: %v", err)

if responseBody.Status == sidecar.ACTIVE {
break
}
return
}

if responseBody.Status == sidecar.ACTIVE {
for _, sc := range responseBody.ServiceComponents {
if strings.Contains(sc.Alias, "postgres") {
hostName = sc.NodesEndpoints[0].Endpoint
break
}
if responseBody.Status != sidecar.STARTING && responseBody.Status != sidecar.PAUSED {
break
}
if hostName == "" {
log.Printf("Failed to get postgres endpoint")
return

if responseBody.Status == sidecar.STARTING && backendConnection != nil {
break
}
} else {
log.Printf("Instance is not active, exiting...")
return
}

defer func() {
if response != nil {
if closeErr := response.Body.Close(); closeErr != nil {
log.Printf("Failed to close response body: %v", closeErr)
time.Sleep(5 * time.Second)
retryCount++

defer func() {
if response != nil {
if closeErr := response.Body.Close(); closeErr != nil {
log.Printf("Failed to close response body: %v", closeErr)
}
}
}
}()
}()
}
}

// Backend port is depends on actual postgres port, in this example, we are using 5432
hostName = hostName + ":5432"
// Step 3: connect to backend postgres server
backendConnection, err := net.DialTCP("tcp", nil, getResolvedAddresses(hostName))
if err != nil {
log.Printf("Remote connection failed: %s", err)
if backendConnection == nil {
log.Printf("Didn't get backend connection established in time, exiting...")
return
}

// Step 4: Forward data from frontend to backend and forward response data from backend to frontend.
go handleIncomingConnection(frontEndConnection, backendConnection, inputBuffer)
go handleIncomingConnection(frontEndConnection, backendConnection)
go handleResponseConnection(backendConnection, frontEndConnection)

// TODO: Close frontend/backend connections
}

/**
* This function is used to forward data from frontend to backend. srcChannel is frontend connection, dstChannel is backend connection.
*/
func handleIncomingConnection(srcChannel, dstChannel *net.TCPConn, firstPacket []byte) {
func handleIncomingConnection(srcChannel, dstChannel *net.TCPConn) {
buff := make([]byte, 0xffff)
firstTime := true

for {
var b []byte
if !firstTime {
n, err := srcChannel.Read(buff)
if err != nil {
log.Printf("Read failed '%s'\n", err)
return
}

// Note that you can add any custom logic, like authentication, authorization
// before sending data to the backend postgres server.
b, err = getModifiedBuffer(buff[:n])
if err != nil {
log.Printf("%s\n", err)
n, err := srcChannel.Read(buff)
if err != nil {
if err == io.EOF {
err = dstChannel.Close()
if err != nil {
log.Printf("connection closed failed '%s'\n", err)
log.Printf("backend connection closed failed '%s'\n", err)
}
return
}
} else {
b = firstPacket
firstTime = false
log.Printf("Read failed '%s'\n", err)
return
}

// Note that you can add any custom logic, like authentication, authorization
// before sending data to the backend serverless resource server.
b, err = getModifiedBuffer(buff[:n])
if err != nil {
log.Printf("%s\n", err)
err = dstChannel.Close()
if err != nil {
log.Printf("connection closed failed '%s'\n", err)
}
return
}

_, err := dstChannel.Write(b)
_, err = dstChannel.Write(b)
if err != nil {
log.Printf("Write failed '%s'\n", err)
return
Expand All @@ -275,7 +258,7 @@ func handleResponseConnection(srcChannel, dstChannel *net.TCPConn) {
}
b := setResponseBuffer(buff[:n])

n, err = dstChannel.Write(b)
_, err = dstChannel.Write(b)
if err != nil {
log.Printf("Write failed '%s'\n", err)
return
Expand Down

0 comments on commit e783f9c

Please sign in to comment.