Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(traefik) allow websocket traffic #275

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,44 @@ import (
"time"

"github.com/gavv/httpexpect/v2"
"github.com/gorilla/websocket"
)

func Test_Blocking_WebSocket(t *testing.T) {
wsURL := "ws://localhost:8080/echo"

conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatal("WebSocket connection failed:", err)
}
defer conn.Close()

done := make(chan bool)
go func() {
defer close(done)
endTime := time.Now().Add(1 * time.Minute)
for time.Now().Before(endTime) {
if err := conn.WriteMessage(websocket.TextMessage, []byte("Hello WebSocket")); err != nil {
t.Error("Write error:", err)
return
}
_, message, err := conn.ReadMessage()
if err != nil {
t.Error("Read error:", err)
return
}
t.Logf("Received: %s", message)
time.Sleep(20 * time.Second)
}
}()

select {
case <-done:
case <-time.After(time.Minute * 2 + 30*time.Second):
t.Fatal("Test did not complete in time")
}
}

func Test_Dynamic(t *testing.T) {
e := httpexpect.Default(t, "http://localhost:8080/dynamic/")

Expand Down
4 changes: 2 additions & 2 deletions plugins/traefik/e2e/docker/dynamic-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ http:
service: "whoami"

whoami-blocking:
rule: PathPrefix(`/blocking/whoami`)
rule: PathPrefix(`/blocking/whoami`) || PathPrefix(`/echo`)
entryPoints:
- "http"
middlewares:
Expand Down Expand Up @@ -48,4 +48,4 @@ http:
- "http"
middlewares:
- healthy@docker
service: "nginx"
service: "nginx"
7 changes: 5 additions & 2 deletions plugins/traefik/e2e/docker/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ destroy_docker_classic() {

run_docker_classic_test() {
echo "Running Docker Classic Test: $1"
TIMEOUT=${2:-30s}
echo "TimeOut set to ${TIMEOUT}"
prepare_docker_classic
sleep 2
go clean -testcache
if ! go test -count=1 -tags e2e -timeout 30s -run ^${1}$ github.com/acouvreur/sablier/e2e; then
if ! go test -count=1 -tags e2e -timeout ${TIMEOUT} -run ^${1}$ github.com/acouvreur/sablier/e2e; then
errors=1
docker compose -f ${DOCKER_COMPOSE_FILE} -p ${DOCKER_COMPOSE_PROJECT_NAME} logs sablier traefik
fi
Expand All @@ -35,5 +37,6 @@ run_docker_classic_test Test_Dynamic
run_docker_classic_test Test_Blocking
run_docker_classic_test Test_Multiple
run_docker_classic_test Test_Healthy
run_docker_classic_test Test_Blocking_WebSocket 3m

exit $errors
exit $errors
2 changes: 1 addition & 1 deletion plugins/traefik/e2e/docker_swarm/docker-stack.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ services:
- traefik.http.routers.whoami-dynamic.rule=PathPrefix(`/dynamic/whoami`)
- traefik.http.routers.whoami-dynamic.service=whoami
- traefik.http.routers.whoami-blocking.middlewares=blocking@docker
- traefik.http.routers.whoami-blocking.rule=PathPrefix(`/blocking/whoami`)
- traefik.http.routers.whoami-blocking.rule=PathPrefix(`/blocking/whoami`) || PathPrefix(`/echo`)
- traefik.http.routers.whoami-blocking.service=whoami
- traefik.http.routers.whoami-multiple.middlewares=multiple@docker
- traefik.http.routers.whoami-multiple.rule=PathPrefix(`/multiple/whoami`)
Expand Down
7 changes: 5 additions & 2 deletions plugins/traefik/e2e/docker_swarm/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ destroy_docker_swarm() {

run_docker_swarm_test() {
echo "Running Docker Swarm Test: $1"
TIMEOUT=${2:-30s}
echo "TimeOut set to ${TIMEOUT}"
prepare_docker_stack
sleep 10
go clean -testcache
if ! go test -count=1 -tags e2e -timeout 30s -run ^${1}$ github.com/acouvreur/sablier/e2e; then
if ! go test -count=1 -tags e2e -timeout ${TIMEOUT} -run ^${1}$ github.com/acouvreur/sablier/e2e; then
errors=1
docker service logs ${DOCKER_STACK_NAME}_sablier
docker service logs ${DOCKER_STACK_NAME}_traefik
Expand All @@ -47,5 +49,6 @@ run_docker_swarm_test Test_Dynamic
run_docker_swarm_test Test_Blocking
run_docker_swarm_test Test_Multiple
run_docker_swarm_test Test_Healthy
run_docker_swarm_test Test_Blocking_WebSocket 3m

exit $errors
exit $errors
9 changes: 8 additions & 1 deletion plugins/traefik/e2e/kubernetes/manifests/deployment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ spec:
name: whoami-service
port:
number: 80
- path: /echo
pathType: Prefix
backend:
service:
name: whoami-service
port:
number: 80
---
apiVersion: networking.k8s.io/v1
kind: Ingress
Expand Down Expand Up @@ -218,4 +225,4 @@ spec:
service:
name: nginx-service
port:
number: 80
number: 80
5 changes: 4 additions & 1 deletion plugins/traefik/e2e/kubernetes/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ destroy_stateful_set() {

run_kubernetes_deployment_test() {
echo "---- Running Kubernetes Test: $1 ----"
TIMEOUT=${2:-30s}
echo "TimeOut set to ${TIMEOUT}"
prepare_deployment
sleep 10
go clean -testcache
if ! go test -count=1 -tags e2e -timeout 30s -run ^${1}$ github.com/acouvreur/sablier/e2e; then
if ! go test -count=1 -tags e2e -timeout ${TIMEOUT} -run ^${1}$ github.com/acouvreur/sablier/e2e; then
errors=1
kubectl -n kube-system logs deployments/sablier-deployment
kubectl -n kube-system logs deployments/traefik
Expand All @@ -68,5 +70,6 @@ run_kubernetes_deployment_test Test_Dynamic
run_kubernetes_deployment_test Test_Blocking
run_kubernetes_deployment_test Test_Multiple
run_kubernetes_deployment_test Test_Healthy
run_kubernetes_deployment_test Test_Blocking_WebSocket 3m

exit $errors
158 changes: 156 additions & 2 deletions plugins/traefik/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,32 @@ import (
"net"
"net/http"
"net/http/httptrace"
"os"
"strings"
"time"
)

type WebSocketEvent int

const (
WebSocketRead WebSocketEvent = iota
WebSocketWrite
WebSocketClose
)

var wsEventChan = make(chan WebSocketEvent)

type SablierMiddleware struct {
client *http.Client
request *http.Request
next http.Handler
useRedirect bool
config *Config
}

// New function creates the configuration
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
req, err := config.BuildRequest(name)

if err != nil {
return nil, err
}
Expand All @@ -31,11 +44,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
next: next,
// there is no way to make blocking work in traefik without redirect so let's make it default
useRedirect: config.Blocking != nil,
config: config,
}, nil
}

func (sm *SablierMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
sablierRequest := sm.request.Clone(context.TODO())
fmt.Println("=== sablierRequest", sablierRequest)

resp, err := sm.client.Do(sablierRequest)
if err != nil {
Expand All @@ -47,6 +62,13 @@ func (sm *SablierMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request

conditonalResponseWriter := newResponseWriter(rw)

if isWebsocketRequest(req) {
// FIXME dynamic make no sense for websocket since client return error
fmt.Println("=== websocket request")
go monitorWebSocketActivity(sablierRequest, sm)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about this approach also. You create go routine per connection. I wondered if this would scale up. I didn't check details though.

Copy link
Author

@DrummyFloyd DrummyFloyd Apr 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should scale correctly , maybe in next iteration , we should add Mutex for security , but not really conformtable with it

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote a proxy for Sablier to handle TCP only services and had a similar problem: How to prevent the service from shutting down, if there are only long running connections.

I don't check if there is actual activity on the connection, but I have a single task/process that will, as long as there is at least 1 active connection, ping the Sablier endpoint in regular intervals. Maybe something like this would work here, too:

https://github.com/vbrandl/sablier-proxy/blob/main/src/main.rs#L154

This way, you don't have to spawn another goroutine per connection, but there will be exactly 1 monitoring task/routine, which might scale better if there are many connections.

conditonalResponseWriter.websocket = true
}

useRedirect := false

if resp.Header.Get("X-Sablier-Session-Status") == "ready" {
Expand Down Expand Up @@ -97,6 +119,7 @@ type responseWriter struct {
responseWriter http.ResponseWriter
headers http.Header
ready bool
websocket bool
}

func (r *responseWriter) Header() http.Header {
Expand All @@ -114,6 +137,11 @@ func (r *responseWriter) Write(buf []byte) (int, error) {
}

func (r *responseWriter) WriteHeader(code int) {
// TODO need to check for code 101? Is it possible that after error connection won't be websocket
if code != http.StatusSwitchingProtocols {
r.websocket = false
}
fmt.Println("=== code", code)
if r.ready == false && code == http.StatusServiceUnavailable {
// We get a 503 HTTP Status Code when there is no backend server in the pool
// to which the request could be sent. Also, note that r.ready
Expand All @@ -136,11 +164,137 @@ func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if !ok {
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", r.responseWriter)
}
return hijacker.Hijack()
if r.websocket {
fmt.Println("=== hijack for websocket")
conn, bufio, err := hijacker.Hijack()
return newConnWrapper(conn), bufio, err
} else {
return hijacker.Hijack()
}
}

func (r *responseWriter) Flush() {
if flusher, ok := r.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

func isWebsocketRequest(req *http.Request) bool {
return containsHeader(req, "Connection", "upgrade") && containsHeader(req, "Upgrade", "websocket")
}

func containsHeader(req *http.Request, name, value string) bool {
items := strings.Split(req.Header.Get(name), ",")
for _, item := range items {
if value == strings.ToLower(strings.TrimSpace(item)) {
return true
}
}
return false
}

func newConnWrapper(c net.Conn) *conn {
return &conn{
conn: c,
}
}

type conn struct {
conn net.Conn
}

// LocalAddr implements net.Conn.
func (c *conn) LocalAddr() net.Addr {
panic("unimplemented")
}

// RemoteAddr implements net.Conn.
func (c *conn) RemoteAddr() net.Addr {
panic("unimplemented")
}

// SetDeadline implements net.Conn.
func (c *conn) SetDeadline(t time.Time) error {
panic("unimplemented")
}

// SetReadDeadline implements net.Conn.
func (c *conn) SetReadDeadline(t time.Time) error {
panic("unimplemented")
}

// SetWriteDeadline implements net.Conn.
func (c *conn) SetWriteDeadline(t time.Time) error {
panic("unimplemented")
}

func (c *conn) Read(b []byte) (n int, err error) {
n, err = c.conn.Read(b)
if err == nil {
wsEventChan <- WebSocketRead // Notify about the read operation
}
return
}

func (c *conn) Write(b []byte) (n int, err error) {
n, err = c.conn.Write(b)
if err == nil {
wsEventChan <- WebSocketWrite // Notify about the write operation
}
return
}

func (c *conn) Close() error {
err := c.conn.Close()
wsEventChan <- WebSocketClose // Notify about the close operation
return err
}

func monitorWebSocketActivity(sablierRequest *http.Request, sm *SablierMiddleware) {
duration, err := time.ParseDuration(sm.config.SessionDuration)
if err != nil {
fmt.Fprintln(os.Stdout, []any{`Error parsing sessionDuration: %v`, err}...)
return
}
alertTime := duration - (duration * 5 / 100) // Calculate alert time at 95% of the total duration
DrummyFloyd marked this conversation as resolved.
Show resolved Hide resolved
alertTicker := time.NewTicker(alertTime)
defer alertTicker.Stop()

// Active flag to determine if the ticker should be reset
activeDuringAlert := false

for {
select {
case event := <-wsEventChan:
switch event {
case WebSocketRead, WebSocketWrite:
activeDuringAlert = true // Mark that there was activity during the alert period

case WebSocketClose:
fmt.Println("WebSocket closed")
_, err := sm.client.Do(sablierRequest)
if err != nil {
fmt.Println("Error in sending request to update websocket alive to sablier", err)
}
alertTicker.Stop()
activeDuringAlert = false // Do not reset ticker on close
DrummyFloyd marked this conversation as resolved.
Show resolved Hide resolved
}

case <-alertTicker.C:
if activeDuringAlert {
fmt.Println("Continuing activity detected, resetting ticker")
_, err := sm.client.Do(sablierRequest)
if err != nil {
fmt.Println("Error in sending request to update websocket alive to sablier", err)
}

alertTicker.Reset(alertTime) // Reset the ticker for another period
activeDuringAlert = false // Reset the activity flag for the next period

} else {
fmt.Println("No activity detected within the alert time,will scaling down")
return
}
}
}
}
Loading