Skip to content

Commit

Permalink
chore: Simplify goroutines structure (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
gruyaume authored Feb 1, 2024
1 parent 68c8b7f commit d72319d
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 100 deletions.
83 changes: 29 additions & 54 deletions cmd/sepp/sepp.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
package main

import (
"context"
"flag"
"os"
"os/signal"
"syscall"
"time"

"log"
"sync"
"time"

"github.com/dot-5g/sepp/config"
"github.com/dot-5g/sepp/internal/n32"
Expand All @@ -23,66 +19,45 @@ func init() {

func main() {
flag.Parse()

var wg sync.WaitGroup
conf, err := config.LoadConfiguration(configFilePath)
if err != nil {
log.Fatalf("Failed to read config file: %s", err)
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

wg.Add(1)
go func() {
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop
cancel()
log.Println("server manually stopped")
defer wg.Done()
n32.StartServer(conf.SEPP.Local.N32.GetAddress(), conf.SEPP.Local.N32.TLS.Cert, conf.SEPP.Local.N32.TLS.Key, conf.SEPP.Local.N32.TLS.CA, conf.SEPP.Local.N32.FQDN)
}()

address := conf.SEPP.Local.N32.Host + ":" + conf.SEPP.Local.N32.Port
go n32.StartServer(ctx, address, conf.SEPP.Local.N32.TLS.Cert, conf.SEPP.Local.N32.TLS.Key, conf.SEPP.Local.N32.TLS.CA, conf.SEPP.Local.N32.FQDN)

remoteURL := conf.SEPP.Remote.URL
if remoteURL != "" {
seppClient := n32.NewClient(conf.SEPP.Local.N32.TLS.Cert, conf.SEPP.Local.N32.TLS.Key, conf.SEPP.Local.N32.TLS.CA)
reqData := n32.SecNegotiateReqData{
Sender: n32.FQDN("testSender"),
SupportedSecCapabilityList: []n32.SecurityCapability{n32.TLS},
}

exchangeCapability(remoteURL, conf.SEPP.Remote.TLS)
wg.Add(1)
go func() {
for {
select {
case <-ctx.Done():
return
default:
if cap, err := seppClient.POSTExchangeCapability(ctx, remoteURL, reqData); err != nil {
log.Printf("failed to exchange capability: %s", err)
waitOrCancel(ctx, 30*time.Second)
} else if cap.SelectedSecCapability == n32.TLS {
log.Println("security exchange successful, starting SBI server...")
sbi.StartServer(ctx, conf)
return
} else {
log.Printf("unsupported capability: %v", cap)
waitOrCancel(ctx, 30*time.Second)
}
}
}
defer wg.Done()
sbi.StartServer(conf)
}()
} else {
log.Println("no remote URL specified, not starting SBI server...")
}

<-ctx.Done() // Wait here until the context is canceled
wg.Wait()
}

func waitOrCancel(ctx context.Context, duration time.Duration) {
select {
case <-ctx.Done():
return
case <-time.After(duration):
return
func exchangeCapability(remoteURL string, n32TLSConf config.TLS) {
seppClient := n32.NewClient(n32TLSConf.Cert, n32TLSConf.Key, n32TLSConf.CA)
reqData := n32.SecNegotiateReqData{
Sender: n32.FQDN("testSender"),
SupportedSecCapabilityList: []n32.SecurityCapability{n32.TLS},
}
for {
cap, err := seppClient.POSTExchangeCapability(remoteURL, reqData)
if err == nil && cap.SelectedSecCapability == n32.TLS {
log.Printf("Successfully exchanged capability: %s", cap.SelectedSecCapability)
break
}
if err != nil {
log.Printf("Failed to exchange capability: %s", err)
} else {
log.Printf("Failed to exchange capability: expected %s, got %s", n32.TLS, cap)
}
time.Sleep(30 * time.Second)
}
}
50 changes: 26 additions & 24 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,42 @@ import (
"gopkg.in/yaml.v2"
)

type TLS struct {
Cert string `yaml:"cert"`
Key string `yaml:"key"`
CA string `yaml:"ca"`
}

type N32 struct {
FQDN string `yaml:"fqdn"`
Host string `yaml:"host"`
Port string `yaml:"port"`
TLS TLS `yaml:"tls"`
}

type SBI struct {
Host string `yaml:"host"`
Port string `yaml:"port"`
TLS TLS `yaml:"tls"`
}

type Config struct {
SEPP struct {
Local struct {
N32 struct {
FQDN string `yaml:"fqdn"`
Host string `yaml:"host"`
Port string `yaml:"port"`
TLS struct {
Cert string `yaml:"cert"`
Key string `yaml:"key"`
CA string `yaml:"ca"`
} `yaml:"tls"`
} `yaml:"n32"`
SBI struct {
Host string `yaml:"host"`
Port string `yaml:"port"`
TLS struct {
Cert string `yaml:"cert"`
Key string `yaml:"key"`
CA string `yaml:"ca"`
} `yaml:"tls"`
} `yaml:"sbi"`
N32 N32 `yaml:"n32"`
SBI SBI `yaml:"sbi"`
} `yaml:"local"`
Remote struct {
URL string `yaml:"url"`
TLS struct {
Cert string `yaml:"cert"`
Key string `yaml:"key"`
CA string `yaml:"ca"`
} `yaml:"tls"`
TLS TLS `yaml:"tls"`
} `yaml:"remote"`
} `yaml:"sepp"`
}

func (n32 N32) GetAddress() string {
return n32.Host + ":" + n32.Port
}

func ReadConfig(reader io.Reader) (*Config, error) {
var config Config

Expand Down
5 changes: 2 additions & 3 deletions internal/n32/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package n32

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
Expand Down Expand Up @@ -43,15 +42,15 @@ func NewClient(certPath string, keyPath string, caCertPath string) *Client {
}
}

func (c *Client) POSTExchangeCapability(ctx context.Context, remoteURL string, secNegotiateReqData SecNegotiateReqData) (SecNegotiateRspData, error) {
func (c *Client) POSTExchangeCapability(remoteURL string, secNegotiateReqData SecNegotiateReqData) (SecNegotiateRspData, error) {
secNegotiateRspData := SecNegotiateRspData{}
jsonData, err := json.Marshal(secNegotiateReqData)
if err != nil {
return secNegotiateRspData, err
}

endpoint := remoteURL + "/n32c-handshake/v1/exchange-capability"
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return secNegotiateRspData, err
}
Expand Down
11 changes: 1 addition & 10 deletions internal/n32/server.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package n32

import (
"context"
"crypto/tls"
"crypto/x509"
"log"
Expand Down Expand Up @@ -35,7 +34,7 @@ func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
}
}

func StartServer(ctx context.Context, address string, serverCertPath string, serverKeyPath string, caCertPath string, fqdn string) {
func StartServer(address string, serverCertPath string, serverKeyPath string, caCertPath string, fqdn string) {
n32c := N32C{FQDN: FQDN(fqdn), Capabilities: []SecurityCapability{TLS}}
http.HandleFunc("/n32c-handshake/v1/exchange-capability", loggingMiddleware(n32c.HandlePostExchangeCapability))
clientCAPool, err := loadClientCAs(caCertPath)
Expand All @@ -50,14 +49,6 @@ func StartServer(ctx context.Context, address string, serverCertPath string, ser
Addr: address,
TLSConfig: tlsConfig,
}

go func() {
<-ctx.Done()
if err := server.Shutdown(context.Background()); err != nil {
log.Printf("SBI server shutdown error: %v", err)
}
}()

log.Printf("starting N32 server on %s", address)
if err := server.ListenAndServeTLS(serverCertPath, serverKeyPath); err != http.ErrServerClosed {
log.Fatalf("failed to start server: %s", err)
Expand Down
10 changes: 1 addition & 9 deletions internal/sbi/server.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sbi

import (
"context"
"log"
"net/http"
"net/http/httputil"
Expand All @@ -19,7 +18,7 @@ func newReverseProxy(targetURL string) *httputil.ReverseProxy {
return httputil.NewSingleHostReverseProxy(url)
}

func StartServer(ctx context.Context, config *config.Config) {
func StartServer(config *config.Config) {
proxy := newReverseProxy(config.SEPP.Remote.URL)
http.Handle("/", proxy)

Expand All @@ -29,13 +28,6 @@ func StartServer(ctx context.Context, config *config.Config) {
Handler: proxy,
}

go func() {
<-ctx.Done()
if err := server.Shutdown(context.Background()); err != nil {
log.Printf("SBI server shutdown error: %v", err)
}
}()

log.Printf("Starting SBI server on %s", address)
if err := server.ListenAndServeTLS(config.SEPP.Local.SBI.TLS.Cert, config.SEPP.Local.SBI.TLS.Key); err != http.ErrServerClosed {
log.Fatalf("Failed to start server: %v", err)
Expand Down

0 comments on commit d72319d

Please sign in to comment.