Skip to content

Commit

Permalink
Add X-Ray daemon (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikschubert authored Mar 21, 2023
1 parent 6fbdf2e commit bede0a0
Show file tree
Hide file tree
Showing 7 changed files with 397 additions and 48 deletions.
28 changes: 5 additions & 23 deletions cmd/localstack/awsutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package main
import (
"context"
"fmt"
"github.com/jessevdk/go-flags"
log "github.com/sirupsen/logrus"
"go.amzn.com/lambda/interop"
"go.amzn.com/lambda/rapidcore"
Expand All @@ -27,29 +26,12 @@ const (
runtimeBootstrap = "/var/runtime/bootstrap"
)

type options struct {
LogLevel string `long:"log-level" default:"info" description:"log level"`
InitCachingEnabled bool `long:"enable-init-caching" description:"Enable support for Init Caching"`
}

func getCLIArgs() (options, []string) {
var opts options
parser := flags.NewParser(&opts, flags.IgnoreUnknown)
args, err := parser.ParseArgs(os.Args)

if err != nil {
log.WithError(err).Fatal("Failed to parse command line arguments:", os.Args)
}

return opts, args
}

func isBootstrapFileExist(filePath string) bool {
file, err := os.Stat(filePath)
return !os.IsNotExist(err) && !file.IsDir()
}

func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) {
func getBootstrap(args []string) (*rapidcore.Bootstrap, string) {
var bootstrapLookupCmd []string
var handler string
currentWorkingDir := "/var/task" // default value
Expand Down Expand Up @@ -148,7 +130,7 @@ func resetListener(changeChannel <-chan bool, server *CustomInteropServer) {

func RunDNSRewriter(opts *LsOpts, ctx context.Context) {
if opts.EnableDnsServer != "1" {
log.Debugln("Dns server disabled")
log.Debugln("DNS server disabled.")
return
}
dnsForwarder, err := NewDnsForwarder(opts.LocalstackIP)
Expand All @@ -160,7 +142,7 @@ func RunDNSRewriter(opts *LsOpts, ctx context.Context) {
dnsForwarder.Start()

<-ctx.Done()
log.Debugln("Shutting down dns server")
log.Debugln("DNS server stopped")
}

func RunHotReloadingListener(server *CustomInteropServer, targetPaths []string, ctx context.Context) {
Expand Down Expand Up @@ -234,11 +216,11 @@ func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.T
// pass to rapid
sandbox.Init(&interop.Init{
Handler: GetenvWithDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv("_HANDLER")),
CorrelationID: "initCorrelationID",
CorrelationID: "initCorrelationID", // TODO
AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"),
AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"),
AwsSession: os.Getenv("AWS_SESSION_TOKEN"),
XRayDaemonAddress: "0.0.0.0:0", // TODO
XRayDaemonAddress: GetenvWithDefault("AWS_XRAY_DAEMON_ADDRESS", "127.0.0.1:2000"),
FunctionName: GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"),
FunctionVersion: functionVersion,

Expand Down
9 changes: 5 additions & 4 deletions cmd/localstack/custom_interop.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@ func (l *LocalStackAdapter) SendStatus(status LocalStackStatus) error {
return nil
}

// The InvokeRequest is sent by LocalStack to trigger an invocation
type InvokeRequest struct {
InvokeId string `json:"invoke-id"`
InvokedFunctionArn string `json:"invoked-function-arn"`
Payload string `json:"payload"`
TraceId string `json:"trace-id"`
}

// The ErrorResponse is sent TO LocalStack when encountering an error
type ErrorResponse struct {
ErrorMessage string `json:"errorMessage"`
ErrorType string `json:"errorType,omitempty"`
Expand Down Expand Up @@ -95,10 +98,8 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate rapidcore.InteropServer, lo
Payload: strings.NewReader(invokeR.Payload), // r.Body,
NeedDebugLogs: true,
CorrelationID: "invokeCorrelationID",
// TODO: should we use the env _X_AMZN_TRACE_ID here or get the value from the request headers from the direct invoke?
// for now we just set a "real" static value
TraceID: "Root=1-53cfd31b-192638fa13e39d2c2bcea001;Parent=365fb4b15f2e3987;Sampled=0", // r.Header.Get("X-Amzn-Trace-Id"),
//TraceID: GetEnvOrDie("_X_AMZN_TRACE_ID"), // r.Header.Get("X-Amzn-Trace-Id"),

TraceID: invokeR.TraceId,
// TODO: set correct segment ID from request
//LambdaSegmentID: "LambdaSegmentID", // r.Header.Get("X-Amzn-Segment-Id"),
//CognitoIdentityID: "",
Expand Down
56 changes: 43 additions & 13 deletions cmd/localstack/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ type LsOpts struct {
HotReloadingPaths []string
EnableDnsServer string
LocalstackIP string
InitLogLevel string
EdgePort string
}

func GetEnvOrDie(env string) string {
Expand All @@ -36,12 +38,15 @@ func GetEnvOrDie(env string) string {

func InitLsOpts() *LsOpts {
return &LsOpts{
// required
RuntimeEndpoint: GetEnvOrDie("LOCALSTACK_RUNTIME_ENDPOINT"),
RuntimeId: GetEnvOrDie("LOCALSTACK_RUNTIME_ID"),
// optional with default
InteropPort: GetenvWithDefault("LOCALSTACK_INTEROP_PORT", "9563"),
InitTracingPort: GetenvWithDefault("LOCALSTACK_RUNTIME_TRACING_PORT", "9564"),
User: GetenvWithDefault("LOCALSTACK_USER", "sbx_user1051"),
InitLogLevel: GetenvWithDefault("LOCALSTACK_INIT_LOG_LEVEL", "debug"),
EdgePort: GetenvWithDefault("EDGE_PORT", "4566"),
// optional or empty
CodeArchives: os.Getenv("LOCALSTACK_CODE_ARCHIVES"),
HotReloadingPaths: strings.Split(GetenvWithDefault("LOCALSTACK_HOT_RELOADING_PATHS", ""), ","),
Expand All @@ -62,6 +67,7 @@ func UnsetLsEnvs() {
"LOCALSTACK_CODE_ARCHIVES",
"LOCALSTACK_HOT_RELOADING_PATHS",
"LOCALSTACK_ENABLE_DNS_SERVER",
"LOCALSTACK_INIT_LOG_LEVEL",
// Docker container ID
"HOSTNAME",
// User
Expand All @@ -78,22 +84,33 @@ func main() {
// we're setting this to the same value as in the official RIE
debug.SetGCPercent(33)

// configuration parsing
lsOpts := InitLsOpts()
UnsetLsEnvs()

// set up logging (logrus)
//log.SetFormatter(&log.JSONFormatter{})
//log.SetLevel(log.TraceLevel)
log.SetLevel(log.DebugLevel)
// set up logging
log.SetReportCaller(true)
switch lsOpts.InitLogLevel {
case "debug":
log.SetLevel(log.DebugLevel)
case "trace":
log.SetFormatter(&log.JSONFormatter{})
log.SetLevel(log.TraceLevel)
default:
log.Fatal("Invalid value for LOCALSTACK_INIT_LOG_LEVEL")
}

// enable dns server
dnsServerContext, stopDnsServer := context.WithCancel(context.Background())
go RunDNSRewriter(lsOpts, dnsServerContext)

// download code archive if env variable is set
if err := DownloadCodeArchives(lsOpts.CodeArchives); err != nil {
log.Fatal("Failed to download code archives")
}
// enable dns server
dnsServerContext, stopDnsServer := context.WithCancel(context.Background())
go RunDNSRewriter(lsOpts, dnsServerContext)

// parse CLI args
bootstrap, handler := getBootstrap(os.Args)

// Switch to non-root user and drop root privileges
if IsRootUser() && lsOpts.User != "" {
Expand All @@ -108,23 +125,36 @@ func main() {
UserLogger().Debugln("Process running as non-root user.")
}

// parse CLI args
opts, args := getCLIArgs()
bootstrap, handler := getBootstrap(args, opts)
logCollector := NewLogCollector()

// file watcher for hot-reloading
fileWatcherContext, cancelFileWatcher := context.WithCancel(context.Background())

// build sandbox
sandbox := rapidcore.
NewSandboxBuilder(bootstrap).
//SetTracer(tracer).
AddShutdownFunc(func() {
log.Debugln("Closing contexts")
log.Debugln("Stopping file watcher")
cancelFileWatcher()
log.Debugln("Stopping DNS server")
stopDnsServer()
}).
AddShutdownFunc(func() { os.Exit(0) }).
SetExtensionsFlag(true).
SetInitCachingFlag(true).
SetTailLogOutput(logCollector)

// xray daemon
xrayConfig := initConfig("http://" + lsOpts.LocalstackIP + ":" + lsOpts.EdgePort)
d := initDaemon(xrayConfig)
sandbox.AddShutdownFunc(func() {
log.Debugln("Shutting down xray daemon")
d.stop()
log.Debugln("Flushing segments in xray daemon")
d.close()
})
runDaemon(d) // async

defaultInterop := sandbox.InteropServer()
interopServer := NewCustomInteropServer(lsOpts, defaultInterop, logCollector)
sandbox.SetInteropServer(interopServer)
Expand All @@ -136,7 +166,7 @@ func main() {
go sandbox.Create()

// get timeout
invokeTimeoutEnv := GetEnvOrDie("AWS_LAMBDA_FUNCTION_TIMEOUT")
invokeTimeoutEnv := GetEnvOrDie("AWS_LAMBDA_FUNCTION_TIMEOUT") // TODO: collect all AWS_* env parsing
invokeTimeoutSeconds, err := strconv.Atoi(invokeTimeoutEnv)
if err != nil {
log.Fatalln(err)
Expand Down
Loading

0 comments on commit bede0a0

Please sign in to comment.