Skip to content

Commit

Permalink
Adapt to new init logic (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
joe4dev authored and dfangl committed Oct 18, 2023
1 parent 62535dc commit 605fa1c
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 122 deletions.
16 changes: 13 additions & 3 deletions cmd/localstack/awsutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
// LOCALSTACK CHANGES 2022-03-10: modified/collected file from /cmd/aws-lambda-rie/* into this util
// LOCALSTACK CHANGES 2022-03-10: minor refactoring of PrintEndReports
// LOCALSTACK CHANGES 2023-10-06: reflect getBootstrap and InitHandler API updates

package main

Expand All @@ -11,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"go.amzn.com/lambda/interop"
"go.amzn.com/lambda/rapidcore"
"go.amzn.com/lambda/rapidcore/env"
"golang.org/x/sys/unix"
"io"
"io/fs"
Expand Down Expand Up @@ -87,7 +89,7 @@ func getBootstrap(args []string) (*rapidcore.Bootstrap, string) {
}
}

return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir), handler
return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir, ""), handler
}

func PrintEndReports(invokeId string, initDuration string, memorySize string, invokeStart time.Time, timeoutDuration time.Duration, w io.Writer) {
Expand Down Expand Up @@ -203,7 +205,7 @@ func getSubFoldersInList(prefix string, pathList []string) (oldFolders []string,
return
}

func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.Time, time.Time) {
func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs interop.Bootstrap) (time.Time, time.Time) {
additionalFunctionEnvironmentVariables := map[string]string{}

// Add default Env Vars if they were not defined. This is a required otherwise 1p Python2.7, Python3.6, and
Expand All @@ -226,15 +228,23 @@ func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.T
// pass to rapid
sandbox.Init(&interop.Init{
Handler: GetenvWithDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv("_HANDLER")),
CorrelationID: "initCorrelationID", // TODO
AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"),
AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"),
AwsSession: os.Getenv("AWS_SESSION_TOKEN"),
XRayDaemonAddress: GetenvWithDefault("AWS_XRAY_DAEMON_ADDRESS", "127.0.0.1:2000"),
FunctionName: GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"),
FunctionVersion: functionVersion,

// TODO: Implement runtime management controls
// https://aws.amazon.com/blogs/compute/introducing-aws-lambda-runtime-management-controls/
RuntimeInfo: interop.RuntimeInfo{
ImageJSON: "{}",
Arn: "",
Version: ""},
CustomerEnvironmentVariables: additionalFunctionEnvironmentVariables,
SandboxType: interop.SandboxClassic,
Bootstrap: bs,
EnvironmentVariables: env.NewEnvironment(),
}, timeout*1000)
initEnd := time.Now()
return initStart, initEnd
Expand Down
148 changes: 41 additions & 107 deletions cmd/localstack/custom_interop.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package main

// Original implementation: lambda/rapidcore/server.go includes Server struct with state
// Server interface between Runtime API and this init: lambda/interop/model.go:358

import (
"bytes"
"encoding/json"
"fmt"
"github.com/go-chi/chi"
log "github.com/sirupsen/logrus"
"go.amzn.com/lambda/core"
"go.amzn.com/lambda/core/statejson"
"go.amzn.com/lambda/interop"
"go.amzn.com/lambda/rapidcore"
Expand Down Expand Up @@ -38,8 +40,8 @@ const (
)

func (l *LocalStackAdapter) SendStatus(status LocalStackStatus, payload []byte) error {
status_url := fmt.Sprintf("%s/status/%s/%s", l.UpstreamEndpoint, l.RuntimeId, status)
_, err := http.Post(status_url, "application/json", bytes.NewReader(payload))
statusUrl := fmt.Sprintf("%s/status/%s/%s", l.UpstreamEndpoint, l.RuntimeId, status)
_, err := http.Post(statusUrl, "application/json", bytes.NewReader(payload))
if err != nil {
return err
}
Expand All @@ -62,7 +64,7 @@ type ErrorResponse struct {
StackTrace []string `json:"stackTrace,omitempty"`
}

func NewCustomInteropServer(lsOpts *LsOpts, delegate rapidcore.InteropServer, logCollector *LogCollector) (server *CustomInteropServer) {
func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollector *LogCollector) (server *CustomInteropServer) {
server = &CustomInteropServer{
delegate: delegate.(*rapidcore.Server),
port: lsOpts.InteropPort,
Expand Down Expand Up @@ -99,9 +101,7 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate rapidcore.InteropServer, lo
InvokedFunctionArn: invokeR.InvokedFunctionArn,
Payload: strings.NewReader(invokeR.Payload), // r.Body,
NeedDebugLogs: true,
CorrelationID: "invokeCorrelationID",

TraceID: invokeR.TraceId,
TraceID: invokeR.TraceId,
// TODO: set correct segment ID from request
//LambdaSegmentID: "LambdaSegmentID", // r.Header.Get("X-Amzn-Segment-Id"),
//CognitoIdentityID: "",
Expand Down Expand Up @@ -194,147 +194,81 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate rapidcore.InteropServer, lo
return server
}

func (c *CustomInteropServer) StartAcceptingDirectInvokes() error {
log.Traceln("Function called")
err := c.localStackAdapter.SendStatus(Ready, []byte{})
if err != nil {
return err
}
return c.delegate.StartAcceptingDirectInvokes()
func (c *CustomInteropServer) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error {
log.Traceln("SendResponse called")
return c.delegate.SendResponse(invokeID, headers, reader, trailers, request)
}

func (c *CustomInteropServer) SendResponse(invokeID string, contentType string, response io.Reader) error {
log.Traceln("Function called")
return c.delegate.SendResponse(invokeID, contentType, response)
func (c *CustomInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error {
log.Traceln("SendErrorResponse called")
return c.delegate.SendErrorResponse(invokeID, response)
}

func (c *CustomInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error {
is, err := c.InternalState()
if err != nil {
return err
}
rs := is.Runtime.State
if rs.Name == core.RuntimeInitErrorStateName {
err = c.localStackAdapter.SendStatus(Error, response.Payload)
if err != nil {
return err
}
// SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT.
func (c *CustomInteropServer) SendInitErrorResponse(invokeID string, response *interop.ErrorResponse) error {
log.Traceln("SendInitErrorResponse called")
if err := c.localStackAdapter.SendStatus(Error, response.Payload); err != nil {
log.Fatalln("Failed to send init error to LocalStack " + err.Error() + ". Exiting.")
}

return c.delegate.SendErrorResponse(invokeID, response)
return c.delegate.SendInitErrorResponse(invokeID, response)
}

func (c *CustomInteropServer) GetCurrentInvokeID() string {
log.Traceln("Function called")
log.Traceln("GetCurrentInvokeID called")
return c.delegate.GetCurrentInvokeID()
}

func (c *CustomInteropServer) CommitResponse() error {
log.Traceln("Function called")
return c.delegate.CommitResponse()
}

func (c *CustomInteropServer) SendRunning(running *interop.Running) error {
log.Traceln("Function called")
return c.delegate.SendRunning(running)
}

func (c *CustomInteropServer) SendRuntimeReady() error {
log.Traceln("Function called")
log.Traceln("SendRuntimeReady called")
return c.delegate.SendRuntimeReady()
}

func (c *CustomInteropServer) SendDone(done *interop.Done) error {
log.Traceln("Function called")
return c.delegate.SendDone(done)
}

func (c *CustomInteropServer) SendDoneFail(fail *interop.DoneFail) error {
log.Traceln("Function called")
return c.delegate.SendDoneFail(fail)
}

func (c *CustomInteropServer) StartChan() <-chan *interop.Start {
log.Traceln("Function called")
return c.delegate.StartChan()
}

func (c *CustomInteropServer) InvokeChan() <-chan *interop.Invoke {
log.Traceln("Function called")
return c.delegate.InvokeChan()
}

func (c *CustomInteropServer) ResetChan() <-chan *interop.Reset {
log.Traceln("Function called")
return c.delegate.ResetChan()
}

func (c *CustomInteropServer) ShutdownChan() <-chan *interop.Shutdown {
log.Traceln("Function called")
return c.delegate.ShutdownChan()
}

func (c *CustomInteropServer) TransportErrorChan() <-chan error {
log.Traceln("Function called")
return c.delegate.TransportErrorChan()
}

func (c *CustomInteropServer) Clear() {
log.Traceln("Function called")
c.delegate.Clear()
}

func (c *CustomInteropServer) IsResponseSent() bool {
log.Traceln("Function called")
return c.delegate.IsResponseSent()
}

func (c *CustomInteropServer) SetInternalStateGetter(cb interop.InternalStateGetter) {
log.Traceln("Function called")
c.delegate.SetInternalStateGetter(cb)
}

func (c *CustomInteropServer) Init(i *interop.Start, invokeTimeoutMs int64) {
log.Traceln("Function called")
c.delegate.Init(i, invokeTimeoutMs)
func (c *CustomInteropServer) Init(i *interop.Init, invokeTimeoutMs int64) error {
log.Traceln("Init called")
return c.delegate.Init(i, invokeTimeoutMs)
}

func (c *CustomInteropServer) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error {
log.Traceln("Function called")
log.Traceln("Invoke called")
return c.delegate.Invoke(responseWriter, invoke)
}

func (c *CustomInteropServer) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error {
log.Traceln("Function called")
log.Traceln("FastInvoke called")
return c.delegate.FastInvoke(w, i, direct)
}

func (c *CustomInteropServer) Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) {
log.Traceln("Function called")
log.Traceln("Reserve called")
return c.delegate.Reserve(id, traceID, lambdaSegmentID)
}

func (c *CustomInteropServer) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) {
log.Traceln("Function called")
log.Traceln("Reset called")
return c.delegate.Reset(reason, timeoutMs)
}

func (c *CustomInteropServer) AwaitRelease() (*statejson.InternalStateDescription, error) {
log.Traceln("Function called")
log.Traceln("AwaitRelease called")
return c.delegate.AwaitRelease()
}

func (c *CustomInteropServer) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription {
log.Traceln("Function called")
return c.delegate.Shutdown(shutdown)
}

func (c *CustomInteropServer) InternalState() (*statejson.InternalStateDescription, error) {
log.Traceln("Function called")
log.Traceln("InternalState called")
return c.delegate.InternalState()
}

func (c *CustomInteropServer) CurrentToken() *interop.Token {
log.Traceln("Function called")
log.Traceln("CurrentToken called")
return c.delegate.CurrentToken()
}

func (c *CustomInteropServer) SetSandboxContext(sbCtx interop.SandboxContext) {
log.Traceln("SetSandboxContext called")
c.delegate.SetSandboxContext(sbCtx)
}

func (c *CustomInteropServer) SetInternalStateGetter(cb interop.InternalStateGetter) {
log.Traceln("SetInternalStateGetter called")
c.delegate.InternalStateGetter = cb
}
31 changes: 31 additions & 0 deletions cmd/localstack/logs_egress_api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package main

import (
"io"
"os"
)

// This LocalStack LogsEgressAPI builder allows to customize log capturing, in our case using the logCollector.

type LocalStackLogsEgressAPI struct {
logCollector *LogCollector
}

func NewLocalStackLogsEgressAPI(logCollector *LogCollector) *LocalStackLogsEgressAPI {
return &LocalStackLogsEgressAPI{
logCollector: logCollector,
}
}

// The interface StdLogsEgressAPI for the functions below is defined in the under cmd/localstack/logs_egress_api.go
// The default implementation is a NoOpLogsEgressAPI

func (s *LocalStackLogsEgressAPI) GetExtensionSockets() (io.Writer, io.Writer, error) {
// os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible).
return io.MultiWriter(s.logCollector, os.Stdout), io.MultiWriter(s.logCollector, os.Stdout), nil
}

func (s *LocalStackLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) {
// os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible).
return io.MultiWriter(s.logCollector, os.Stdout), io.MultiWriter(s.logCollector, os.Stdout), nil
}
36 changes: 28 additions & 8 deletions cmd/localstack/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,15 @@ func main() {
}
}

logCollector := NewLogCollector()

// file watcher for hot-reloading
fileWatcherContext, cancelFileWatcher := context.WithCancel(context.Background())

logCollector := NewLogCollector()
localStackLogsEgressApi := NewLocalStackLogsEgressAPI(logCollector)

// build sandbox
sandbox := rapidcore.
NewSandboxBuilder(bootstrap).
NewSandboxBuilder().
//SetTracer(tracer).
AddShutdownFunc(func() {
log.Debugln("Stopping file watcher")
Expand All @@ -178,7 +179,7 @@ func main() {
}).
SetExtensionsFlag(true).
SetInitCachingFlag(true).
SetTailLogOutput(logCollector)
SetLogsEgressAPI(localStackLogsEgressApi)

// xray daemon
endpoint := "http://" + lsOpts.LocalstackIP + ":" + lsOpts.EdgePort
Expand All @@ -192,7 +193,7 @@ func main() {
})
runDaemon(d) // async

defaultInterop := sandbox.InteropServer()
defaultInterop := sandbox.DefaultInteropServer()
interopServer := NewCustomInteropServer(lsOpts, defaultInterop, logCollector)
sandbox.SetInteropServer(interopServer)
if len(handler) > 0 {
Expand All @@ -204,7 +205,10 @@ func main() {
})

// initialize all flows and start runtime API
go sandbox.Create()
sandboxContext, internalStateFn := sandbox.Create()
// Populate our custom interop server
interopServer.SetSandboxContext(sandboxContext)
interopServer.SetInternalStateGetter(internalStateFn)

// get timeout
invokeTimeoutEnv := GetEnvOrDie("AWS_LAMBDA_FUNCTION_TIMEOUT") // TODO: collect all AWS_* env parsing
Expand All @@ -214,8 +218,24 @@ func main() {
}
go RunHotReloadingListener(interopServer, lsOpts.HotReloadingPaths, fileWatcherContext)

// start runtime init
go InitHandler(sandbox, GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION"), int64(invokeTimeoutSeconds)) // TODO: replace this with a custom init
// start runtime init. It is important to start `InitHandler` synchronously because we need to ensure the
// notification channels and status fields are properly initialized before `AwaitInitialized`
log.Debugln("Starting runtime init.")
InitHandler(sandbox.LambdaInvokeAPI(), GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION"), int64(invokeTimeoutSeconds), bootstrap) // TODO: replace this with a custom init

log.Debugln("Awaiting initialization of runtime init.")
if err := interopServer.delegate.AwaitInitialized(); err != nil {
// Error cases: ErrInitDoneFailed or ErrInitResetReceived
log.Errorln("Runtime init failed to initialize: " + err.Error() + ". Exiting.")
// NOTE: Sending the error status to LocalStack is handled beforehand in the custom_interop.go through the
// callback SendInitErrorResponse because it contains the correct error response payload.
return
}

log.Debugln("Completed initialization of runtime init. Sending status ready to LocalStack.")
if err := interopServer.localStackAdapter.SendStatus(Ready, []byte{}); err != nil {
log.Fatalln("Failed to send status ready to LocalStack " + err.Error() + ". Exiting.")
}

<-exitChan
}
Loading

0 comments on commit 605fa1c

Please sign in to comment.