Skip to content

Commit

Permalink
feat: support CtxOption for SetCtx
Browse files Browse the repository at this point in the history
  • Loading branch information
pepesi committed Nov 25, 2024
1 parent a759338 commit bb4f838
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 62 deletions.
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,14 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
if err == nil {
checkStream(&ctx, &log)
checkStream(&ctx, log)
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_resp_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process response headers: %v", err))
return types.ActionContinue
}

checkStream(&ctx, &log)
checkStream(&ctx, log)
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if !needHandleBody && !needHandleStreamingBody {
Expand Down Expand Up @@ -254,7 +254,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
return types.ActionContinue
}

func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
func checkStream(ctx *wrapper.HttpContext, log wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
if err != nil {
Expand Down
45 changes: 30 additions & 15 deletions plugins/wasm-go/pkg/wrapper/log_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,26 @@ const (
LogLevelCritical
)

type Log struct {
type Log interface {
Trace(msg string)
Tracef(format string, args ...interface{})
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Warn(msg string)
Warnf(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
Critical(msg string)
Criticalf(format string, args ...interface{})
}

type DefaultLog struct {
pluginName string
}

func (l Log) log(level LogLevel, msg string) {
func (l *DefaultLog) log(level LogLevel, msg string) {
requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"})
requestID := string(requestIDRaw)
if requestID == "" {
Expand All @@ -58,7 +73,7 @@ func (l Log) log(level LogLevel, msg string) {
}
}

func (l Log) logFormat(level LogLevel, format string, args ...interface{}) {
func (l *DefaultLog) logFormat(level LogLevel, format string, args ...interface{}) {
requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"})
requestID := string(requestIDRaw)
if requestID == "" {
Expand All @@ -81,50 +96,50 @@ func (l Log) logFormat(level LogLevel, format string, args ...interface{}) {
}
}

func (l Log) Trace(msg string) {
func (l *DefaultLog) Trace(msg string) {
l.log(LogLevelTrace, msg)
}

func (l Log) Tracef(format string, args ...interface{}) {
func (l *DefaultLog) Tracef(format string, args ...interface{}) {
l.logFormat(LogLevelTrace, format, args...)
}

func (l Log) Debug(msg string) {
func (l *DefaultLog) Debug(msg string) {
l.log(LogLevelDebug, msg)
}

func (l Log) Debugf(format string, args ...interface{}) {
func (l *DefaultLog) Debugf(format string, args ...interface{}) {
l.logFormat(LogLevelDebug, format, args...)
}

func (l Log) Info(msg string) {
func (l *DefaultLog) Info(msg string) {
l.log(LogLevelInfo, msg)
}

func (l Log) Infof(format string, args ...interface{}) {
func (l *DefaultLog) Infof(format string, args ...interface{}) {
l.logFormat(LogLevelInfo, format, args...)
}

func (l Log) Warn(msg string) {
func (l *DefaultLog) Warn(msg string) {
l.log(LogLevelWarn, msg)
}

func (l Log) Warnf(format string, args ...interface{}) {
func (l *DefaultLog) Warnf(format string, args ...interface{}) {
l.logFormat(LogLevelWarn, format, args...)
}

func (l Log) Error(msg string) {
func (l *DefaultLog) Error(msg string) {
l.log(LogLevelError, msg)
}

func (l Log) Errorf(format string, args ...interface{}) {
func (l *DefaultLog) Errorf(format string, args ...interface{}) {
l.logFormat(LogLevelError, format, args...)
}

func (l Log) Critical(msg string) {
func (l *DefaultLog) Critical(msg string) {
l.log(LogLevelCritical, msg)
}

func (l Log) Criticalf(format string, args ...interface{}) {
func (l *DefaultLog) Criticalf(format string, args ...interface{}) {
l.logFormat(LogLevelCritical, format, args...)
}
166 changes: 122 additions & 44 deletions plugins/wasm-go/pkg/wrapper/plugin_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,79 +98,157 @@ func RegisteTickFunc(tickPeriod int64, tickFunc func()) {
globalOnTickFuncs = append(globalOnTickFuncs, TickFuncEntry{0, tickPeriod, tickFunc})
}

func SetCtx[PluginConfig any](pluginName string, setFuncs ...SetPluginFunc[PluginConfig]) {
proxywasm.SetVMContext(NewCommonVmCtx(pluginName, setFuncs...))
func SetCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) {
proxywasm.SetVMContext(NewCommonVmCtx(pluginName, options...))
}

type SetPluginFunc[PluginConfig any] func(*CommonVmCtx[PluginConfig])
func SetCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) {
proxywasm.SetVMContext(NewCommonVmCtxWithOptions(pluginName, options...))
}

func ParseConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.parseConfig = f
}
type CtxOption[PluginConfig any] interface {
Apply(*CommonVmCtx[PluginConfig])
}

func ParseOverrideConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig], g ParseRuleConfigFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.parseConfig = f
ctx.parseRuleConfig = g
}
type parseConfigOption[PluginConfig any] struct {
f ParseConfigFunc[PluginConfig]
}

func ProcessRequestHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpRequestHeaders = f
}
func (o parseConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.parseConfig = o.f
}

func ProcessRequestBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpRequestBody = f
}
func ParseConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig]) CtxOption[PluginConfig] {
return parseConfigOption[PluginConfig]{f}
}

func ProcessStreamingRequestBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamingRequestBody = f
}
type parseOverrideConfigOption[PluginConfig any] struct {
parseConfigF ParseConfigFunc[PluginConfig]
parseRuleConfigF ParseRuleConfigFunc[PluginConfig]
}

func ProcessResponseHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpResponseHeaders = f
}
func (o *parseOverrideConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.parseConfig = o.parseConfigF
ctx.parseRuleConfig = o.parseRuleConfigF
}

func ProcessResponseBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpResponseBody = f
}
func ParseOverrideConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig], g ParseRuleConfigFunc[PluginConfig]) CtxOption[PluginConfig] {
return &parseOverrideConfigOption[PluginConfig]{f, g}
}

func ProcessStreamingResponseBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamingResponseBody = f
}
type onProcessRequestHeadersOption[PluginConfig any] struct {
f onHttpHeadersFunc[PluginConfig]
}

func ProcessStreamDoneBy[PluginConfig any](f onHttpStreamDoneFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamDone = f
}
func (o *onProcessRequestHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpRequestHeaders = o.f
}

func ProcessRequestHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessRequestHeadersOption[PluginConfig]{f}
}

type onProcessRequestBodyOption[PluginConfig any] struct {
f onHttpBodyFunc[PluginConfig]
}

func (o *onProcessRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpRequestBody = o.f
}

func ProcessRequestBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessRequestBodyOption[PluginConfig]{f}
}

type onProcessStreamingRequestBodyOption[PluginConfig any] struct {
f onHttpStreamingBodyFunc[PluginConfig]
}

func (o *onProcessStreamingRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamingRequestBody = o.f
}

func ProcessStreamingRequestBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamingRequestBodyOption[PluginConfig]{f}
}

type onProcessResponseHeadersOption[PluginConfig any] struct {
f onHttpHeadersFunc[PluginConfig]
}

func (o *onProcessResponseHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpResponseHeaders = o.f
}

func ProcessResponseHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessResponseHeadersOption[PluginConfig]{f}
}

type onProcessResponseBodyOption[PluginConfig any] struct {
f onHttpBodyFunc[PluginConfig]
}

func (o *onProcessResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpResponseBody = o.f
}

func ProcessResponseBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessResponseBodyOption[PluginConfig]{f}
}

type onProcessStreamingResponseBodyOption[PluginConfig any] struct {
f onHttpStreamingBodyFunc[PluginConfig]
}

func (o *onProcessStreamingResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamingResponseBody = o.f
}

func ProcessStreamingResponseBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamingResponseBodyOption[PluginConfig]{f}
}

type onProcessStreamDoneOption[PluginConfig any] struct {
f onHttpStreamDoneFunc[PluginConfig]
}

func (o *onProcessStreamDoneOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamDone = o.f
}

func ProcessStreamDoneBy[PluginConfig any](f onHttpStreamDoneFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamDoneOption[PluginConfig]{f}
}

type logOption[PluginConfig any] struct {
logger Log
}

func (o *logOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.log = o.logger
}

func WithLogger[PluginConfig any](logger Log) CtxOption[PluginConfig] {
return &logOption[PluginConfig]{logger}
}

func parseEmptyPluginConfig[PluginConfig any](gjson.Result, *PluginConfig, Log) error {
return nil
}

func NewCommonVmCtx[PluginConfig any](pluginName string, setFuncs ...SetPluginFunc[PluginConfig]) *CommonVmCtx[PluginConfig] {
func NewCommonVmCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] {
logger := &DefaultLog{pluginName}
opts := append([]CtxOption[PluginConfig]{WithLogger[PluginConfig](logger)}, options...)
return NewCommonVmCtxWithOptions(pluginName, opts...)
}

func NewCommonVmCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] {
ctx := &CommonVmCtx[PluginConfig]{
pluginName: pluginName,
log: Log{pluginName},
hasCustomConfig: true,
}
for _, set := range setFuncs {
set(ctx)
for _, opt := range options {
opt.Apply(ctx)
}
if ctx.parseConfig == nil {
var config PluginConfig
Expand Down

0 comments on commit bb4f838

Please sign in to comment.