From fc388d28ae137dc1e5271dc6df5ad30cf30e5c98 Mon Sep 17 00:00:00 2001 From: Abdelrahman Ahmed Date: Sun, 26 Jul 2020 14:16:55 +0200 Subject: [PATCH] Improve routing (#64) * improve routing * update go.sum * fix allowed method calls * update readme * allow cache for post and get methods * support serving static * fix prefix lenght value * improve context comments * improve comments * update readme --- .github/workflows/test_and_build.yml | 2 +- README.md | 80 ++-- cache.go | 85 ----- cache_test.go | 42 --- context.go | 75 +++- context_test.go | 6 +- gearbox.go | 234 ++++++++---- gearbox_test.go | 322 ++++++++++++---- go.sum | 8 +- router.go | 534 ++++++++------------------- router_test.go | 523 +++----------------------- tree.go | 168 +++++++++ tree_test.go | 133 +++++++ 13 files changed, 1021 insertions(+), 1191 deletions(-) delete mode 100644 cache.go delete mode 100644 cache_test.go create mode 100644 tree.go create mode 100644 tree_test.go diff --git a/.github/workflows/test_and_build.yml b/.github/workflows/test_and_build.yml index 7130b80..c3900cf 100644 --- a/.github/workflows/test_and_build.yml +++ b/.github/workflows/test_and_build.yml @@ -43,7 +43,7 @@ jobs: - name: Build run: go build - + - name: Lint run: golint -set_exit_status ./... diff --git a/README.md b/README.md index c60903f..3b97378 100644 --- a/README.md +++ b/README.md @@ -20,17 +20,13 @@

-**gearbox** :gear: is a web framework for building micro services written in Go with a focus on high performance and memory optimization. +**gearbox** :gear: is a web framework for building micro services written in Go with a focus on high performance and memory optimization. It's built on [fasthttp](https://github.com/valyala/fasthttp) which is **up to 10x faster** than net/http -Currently, **gearbox** :gear: is **under development (not production ready)** and built on [fasthttp](https://github.com/valyala/fasthttp) which is **up to 10x faster** than net/http - -In **gearbox**, we care about peformance and memory which will be used by each method while building things up and how we can improve that. It also takes more time to **research** about each component that will be used and **compare** it with different implementations of other open source web frameworks. It may end up writing our **own components** in an optimized way to achieve our goals ### gearbox seeks to be + Secure :closed_lock_with_key: + Fast :rocket: -+ Simple :eyeglasses: -+ Easy to use ++ Easy to use :eyeglasses: + Lightweight @@ -58,8 +54,8 @@ func main() { gb := gearbox.New() // Define your handlers - gb.Get("/hello", func(ctx *gearbox.Context) { - ctx.RequestCtx.Response.SetBodyString("Hello World!") + gb.Get("/hello", func(ctx gearbox.Context) { + ctx.SendString("Hello World!") }) // Start service @@ -80,18 +76,8 @@ func main() { gb := gearbox.New() // Handler with parameter - gb.Get("/users/:user", func(ctx *gearbox.Context) { - fmt.Printf("%s\n", ctx.Params["user"]) - }) - - // Handler with optional parameter - gb.Get("/search/:pattern?", func(ctx *gearbox.Context) { - fmt.Printf("%s\n", ctx.Params["pattern"]) - }) - - // Handler with regex parameter - gb.Get("/book/:name:([a-z]+[0-3])", func(ctx *gearbox.Context) { - fmt.Printf("%s\n", ctx.Params["name"]) + gb.Get("/users/:user", func(ctx gearbox.Context) { + ctx.SendString(ctx.Param("user")) }) // Start service @@ -113,33 +99,32 @@ func main() { gb := gearbox.New() // create a logger middleware - logMiddleware := func(ctx *gearbox.Context) { - log.Printf(ctx.RequestCtx.String()) + logMiddleware := func(ctx gearbox.Context) { + log.Printf("log message!") ctx.Next() // Next is what allows the request to continue to the next middleware/handler } // create an unauthorized middleware - unAuthorizedMiddleware := func(ctx *gearbox.Context) { - ctx.RequestCtx.SetStatusCode(401) // unauthorized status code - ctx.RequestCtx.Response.SetBodyString("You are unauthorized to access this page!") + unAuthorizedMiddleware := func(ctx gearbox.Context) { + ctx.Status(gearbox.StatusUnauthorized).SendString("You are unauthorized to access this page!") } // Register the log middleware for all requests gb.Use(logMiddleware) // Define your handlers - gb.Get("/hello", func(ctx *gearbox.Context) { - ctx.RequestCtx.Response.SetBodyString("Hello World!") + gb.Get("/hello", func(ctx gearbox.Context) { + ctx.SendString("Hello World!") }) - + // Register the routes to be used when grouping routes - routes := []*gearbox.Route { - gb.Get("/id", func(ctx *gearbox.Context) { - ctx.RequestCtx.Response.SetBodyString("User X") + routes := []*gearbox.Route{ + gb.Get("/id", func(ctx gearbox.Context) { + ctx.SendString("User X") + }), + gb.Delete("/id", func(ctx gearbox.Context) { + ctx.SendString("Deleted") }), - gb.Delete("/id", func(ctx *gearbox.Context) { - ctx.RequestCtx.Response.SetBodyString("Deleted") - }) } // Group account routes @@ -150,8 +135,8 @@ func main() { // Define a route with unAuthorizedMiddleware as the middleware // you can define as many middlewares as you want and have the handler as the last argument - gb.Get("/protected", unAuthorizedMiddleware, func(ctx *gearbox.Context) { - ctx.RequestCtx.Response.SetBodyString("You accessed a protected page") + gb.Get("/protected", unAuthorizedMiddleware, func(ctx gearbox.Context) { + ctx.SendString("You accessed a protected page") }) // Start service @@ -159,6 +144,29 @@ func main() { } ``` +#### Static Files + +```go +package main + +import ( + "github.com/gogearbox/gearbox" +) + +func main() { + // Setup gearbox + gb := gearbox.New() + + // Serve files in assets directory for prefix static + // for example /static/gearbox.png, etc. + gb.Static("/static", "./assets") + + // Start service + gb.Start(":3000") +} +``` + + ### Contribute & Support + Add a [GitHub Star](https://github.com/gogearbox/gearbox/stargazers) + [Suggest new features, ideas and optimizations](https://github.com/gogearbox/gearbox/issues) diff --git a/cache.go b/cache.go deleted file mode 100644 index 4345eb1..0000000 --- a/cache.go +++ /dev/null @@ -1,85 +0,0 @@ -package gearbox - -import ( - "container/list" - "sync" -) - -// Implementation of LRU caching using doubly linked list and tst - -// Cache returns LRU cache -type Cache interface { - Set(key string, value interface{}) - Get(key string) interface{} -} - -// lruCache holds info used for caching internally -type lruCache struct { - capacity int - list *list.List - store map[string]interface{} - mutex sync.RWMutex -} - -// pair contains key and value of element -type pair struct { - key string - value interface{} -} - -// NewCache returns LRU cache -func NewCache(capacity int) Cache { - // minimum is 1 - if capacity <= 0 { - capacity = 1 - } - - return &lruCache{ - capacity: capacity, - list: new(list.List), - store: make(map[string]interface{}), - } -} - -// Get returns value of provided key if it's existing -func (c *lruCache) Get(key string) interface{} { - c.mutex.RLock() - defer c.mutex.RUnlock() - - // check if list node exists - if node, ok := c.store[key].(*list.Element); ok { - c.list.MoveToFront(node) - - return node.Value.(*pair).value - } - return nil -} - -// Set adds a value to provided key in cache -func (c *lruCache) Set(key string, value interface{}) { - c.mutex.Lock() - defer c.mutex.Unlock() - - // update the value if key is existing - if node, ok := c.store[key].(*list.Element); ok { - c.list.MoveToFront(node) - - node.Value.(*pair).value = value - return - } - - // remove last node if cache is full - if c.list.Len() == c.capacity { - lastNode := c.list.Back() - - // delete key's value - delete(c.store, lastNode.Value.(*pair).key) - - c.list.Remove(lastNode) - } - - c.store[key] = c.list.PushFront(&pair{ - key: key, - value: value, - }) -} diff --git a/cache_test.go b/cache_test.go deleted file mode 100644 index c5d0d50..0000000 --- a/cache_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package gearbox - -import ( - "fmt" -) - -// ExampleNewCache tests Cache set and get methods -func ExampleNewCache() { - cache := NewCache(3) - cache.Set("user1", 1) - fmt.Println(cache.Get("user1").(int)) - - cache.Set("user2", 2) - fmt.Println(cache.Get("user2").(int)) - - cache.Set("user3", 3) - fmt.Println(cache.Get("user3").(int)) - - cache.Set("user4", 4) - fmt.Println(cache.Get("user1")) - fmt.Println(cache.Get("user2").(int)) - - cache.Set("user5", 5) - fmt.Println(cache.Get("user3")) - - cache.Set("user5", 6) - fmt.Println(cache.Get("user5").(int)) - - cache2 := NewCache(0) - cache2.Set("user1", 1) - fmt.Println(cache2.Get("user1").(int)) - - // Output: - // 1 - // 2 - // 3 - // - // 2 - // - // 6 - // 1 -} diff --git a/context.go b/context.go index dcb6ecc..7738014 100644 --- a/context.go +++ b/context.go @@ -4,23 +4,80 @@ import ( "github.com/valyala/fasthttp" ) +// Context interface +type Context interface { + Next() + Context() *fasthttp.RequestCtx + Param(key string) string + Query(key string) string + SendString(value string) Context + Status(status int) Context + Set(key string, value string) + Get(key string) string + Body() string +} + // handlerFunc defines the handler used by middleware as return value. -type handlerFunc func(ctx *Context) +type handlerFunc func(ctx Context) // handlersChain defines a handlerFunc array. type handlersChain []handlerFunc // Context defines the current context of request and handlers/middlewares to execute -type Context struct { - RequestCtx *fasthttp.RequestCtx - Params map[string]string - handlers handlersChain - index int +type context struct { + requestCtx *fasthttp.RequestCtx + paramValues map[string]string + handlers handlersChain + index int } // Next function is used to successfully pass from current middleware to next middleware. -// if the middleware thinks it's okay to pass it. -func (ctx *Context) Next() { +// if the middleware thinks it's okay to pass it +func (ctx *context) Next() { ctx.index++ - ctx.handlers[ctx.index](ctx) + if ctx.index < len(ctx.handlers) { + ctx.handlers[ctx.index](ctx) + } +} + +// Param returns value of path parameter specified by key +func (ctx *context) Param(key string) string { + return ctx.paramValues[key] +} + +// Context returns Fasthttp context +func (ctx *context) Context() *fasthttp.RequestCtx { + return ctx.requestCtx +} + +// SendString sends body of response as a string +func (ctx *context) SendString(value string) Context { + ctx.requestCtx.SetBodyString(value) + return ctx +} + +// Status sets the HTTP status code +func (ctx *context) Status(status int) Context { + ctx.requestCtx.Response.SetStatusCode(status) + return ctx +} + +// Get returns the HTTP request header specified by field key +func (ctx *context) Get(key string) string { + return GetString(ctx.requestCtx.Request.Header.Peek(key)) +} + +// Set sets the response's HTTP header field key to the specified key, value +func (ctx *context) Set(key, value string) { + ctx.requestCtx.Response.Header.Set(key, value) +} + +// Query returns the query string parameter in the request url +func (ctx *context) Query(key string) string { + return GetString(ctx.requestCtx.QueryArgs().Peek(key)) +} + +// Body contains the raw body submitted in a POST request +func (ctx *context) Body() string { + return GetString(ctx.requestCtx.Request.Body()) } diff --git a/context_test.go b/context_test.go index 8882506..13a9289 100644 --- a/context_test.go +++ b/context_test.go @@ -6,7 +6,7 @@ import ( ) // Test passing the request from middleware to handler -func Test_Next(t *testing.T) { +func TestNext(t *testing.T) { // testing routes routes := []struct { path string @@ -17,9 +17,7 @@ func Test_Next(t *testing.T) { } // get instance of gearbox - gb := new(gearbox) - gb.registeredRoutes = make([]*Route, 0) - gb.settings = &Settings{} + gb := setupGearbox() // register routes according to method for _, r := range routes { diff --git a/gearbox.go b/gearbox.go index bfb2c40..0cc0bbe 100644 --- a/gearbox.go +++ b/gearbox.go @@ -2,10 +2,11 @@ package gearbox import ( - "fmt" "log" "net" "os" + "strings" + "sync" "time" "github.com/valyala/fasthttp" @@ -14,20 +15,17 @@ import ( // Exported constants const ( - Version = "1.0.3" // Version of gearbox + Version = "1.1.0" // Version of gearbox Name = "Gearbox" // Name of gearbox // http://patorjk.com/software/taag/#p=display&f=Big%20Money-ne&t=Gearbox banner = ` - /$$$$$$ /$$ - /$$__ $$ | $$ -| $$ \__/ /$$$$$$ /$$$$$$ /$$$$$$ | $$$$$$$ /$$$$$$ /$$ /$$ -| $$ /$$$$ /$$__ $$ |____ $$ /$$__ $$| $$__ $$ /$$__ $$| $$ /$$/ -| $$|_ $$| $$$$$$$$ /$$$$$$$| $$ \__/| $$ \ $$| $$ \ $$ \ $$$$/ -| $$ \ $$| $$_____/ /$$__ $$| $$ | $$ | $$| $$ | $$ >$$ $$ -| $$$$$$/| $$$$$$$| $$$$$$$| $$ | $$$$$$$/| $$$$$$/ /$$/\ $$ - \______/ \_______/ \_______/|__/ |_______/ \______/ |__/ \__/ %s -Listening on %s -` + _____ _ + / ____| | | +| | __ ___ __ _ _ __ | |__ ___ __ __ +| | |_ | / _ \ / _' || '__|| '_ \ / _ \\ \/ / +| |__| || __/| (_| || | | |_) || (_) |> < + \_____| \___| \__,_||_| |_.__/ \___//_/\_\ v%s +Listening on %s` ) const ( @@ -141,32 +139,44 @@ type Gearbox interface { Connect(path string, handlers ...handlerFunc) *Route Options(path string, handlers ...handlerFunc) *Route Trace(path string, handlers ...handlerFunc) *Route - Method(method, path string, handlers ...handlerFunc) *Route - Fallback(handlers ...handlerFunc) error - Use(middlewares ...handlerFunc) Group(prefix string, routes []*Route) []*Route + Static(prefix, root string) + NotFound(handlers ...handlerFunc) + Use(middlewares ...handlerFunc) } // gearbox implements Gearbox interface type gearbox struct { - httpServer *fasthttp.Server - routingTreeRoot *routeNode - registeredRoutes []*Route - address string // server address - handlers handlersChain - registeredFallback *routerFallback - cache Cache - settings *Settings + httpServer *fasthttp.Server + router *router + registeredRoutes []*Route + address string // server address + middlewares handlersChain + settings *Settings } // Settings struct holds server settings type Settings struct { // Enable case sensitive routing - CaseSensitive bool // default false + CaseInSensitive bool // default false // Maximum size of LRU cache that will be used in routing if it's enabled CacheSize int // default 1000 + // Enables answering with HTTP status code 405 if request does not match + // with any route, but there are another methods are allowed for that route + // otherwise answer with Not Found handlers or status code 404. + HandleMethodNotAllowed bool // default false + + // Enables automatic replies to OPTIONS requests if there are no handlers + // registered for that route + HandleOPTIONS bool // default false + + // Enables automatic recovering from panic while executing handlers by + // answering with HTTP status code 500 and logging error message without + // stopping service + AutoRecover bool // default false + // ServerName for sending in response headers ServerName string // default "" @@ -262,6 +272,17 @@ func New(settings ...*Settings) Gearbox { gb.settings.Concurrency = defaultConcurrency } + // Initialize router + gb.router = &router{ + settings: gb.settings, + cache: make(map[string]*matchResult), + pool: sync.Pool{ + New: func() interface{} { + return new(context) + }, + }, + } + gb.httpServer = gb.newHTTPServer() return gb @@ -269,12 +290,8 @@ func New(settings ...*Settings) Gearbox { // Start handling requests func (gb *gearbox) Start(address string) error { - // Construct routing tree - if err := gb.constructRoutingTree(); err != nil { - return fmt.Errorf("unable to construct routing %s", err.Error()) - } - - gb.cache = NewCache(gb.settings.CacheSize) + // Setup router + gb.setupRouter() if gb.settings.Prefork { if !gb.settings.DisableStartupMessage { @@ -311,12 +328,13 @@ func (gb *gearbox) Start(address string) error { type customLogger struct{} func (dl *customLogger) Printf(format string, args ...interface{}) { + //log.Printf(format) } // newHTTPServer returns a new instance of fasthttp server func (gb *gearbox) newHTTPServer() *fasthttp.Server { return &fasthttp.Server{ - Handler: gb.handler, + Handler: gb.router.Handler, Logger: &customLogger{}, LogAllErrors: false, Name: gb.settings.ServerName, @@ -332,6 +350,34 @@ func (gb *gearbox) newHTTPServer() *fasthttp.Server { } } +// registerRoute registers handlers with method and path +func (gb *gearbox) registerRoute(method, path string, handlers handlersChain) *Route { + if gb.settings.CaseInSensitive { + path = strings.ToLower(path) + } + + route := &Route{ + Path: path, + Method: method, + Handlers: handlers, + } + + // Add route to registered routes + gb.registeredRoutes = append(gb.registeredRoutes, route) + return route +} + +// setupRouter initializes router with registered routes +func (gb *gearbox) setupRouter() { + for _, route := range gb.registeredRoutes { + gb.router.handle(route.Method, route.Path, append(gb.middlewares, route.Handlers...)) + } + + // Frees intermediate stores after initializing router + gb.registeredRoutes = nil + gb.middlewares = nil +} + // Stop serving func (gb *gearbox) Stop() error { err := gb.httpServer.Shutdown() @@ -347,68 +393,50 @@ func (gb *gearbox) Stop() error { // Get registers an http relevant method func (gb *gearbox) Get(path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(MethodGet), string(path), handlers) + return gb.registerRoute(MethodGet, path, handlers) } // Head registers an http relevant method func (gb *gearbox) Head(path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(MethodHead), string(path), handlers) + return gb.registerRoute(MethodHead, path, handlers) } // Post registers an http relevant method func (gb *gearbox) Post(path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(MethodPost), string(path), handlers) + return gb.registerRoute(MethodPost, path, handlers) } // Put registers an http relevant method func (gb *gearbox) Put(path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(MethodPut), string(path), handlers) + return gb.registerRoute(MethodPut, path, handlers) } // Patch registers an http relevant method func (gb *gearbox) Patch(path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(MethodPatch), string(path), handlers) + return gb.registerRoute(MethodPatch, path, handlers) } // Delete registers an http relevant method func (gb *gearbox) Delete(path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(MethodDelete), string(path), handlers) + return gb.registerRoute(MethodDelete, path, handlers) } // Connect registers an http relevant method func (gb *gearbox) Connect(path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(MethodConnect), string(path), handlers) + return gb.registerRoute(MethodConnect, path, handlers) } // Options registers an http relevant method func (gb *gearbox) Options(path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(MethodOptions), string(path), handlers) + return gb.registerRoute(MethodOptions, path, handlers) } // Trace registers an http relevant method func (gb *gearbox) Trace(path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(MethodTrace), string(path), handlers) -} - -// Trace registers an http relevant method -func (gb *gearbox) Method(method, path string, handlers ...handlerFunc) *Route { - return gb.registerRoute(string(method), string(path), handlers) -} - -// Fallback registers an http handler only fired when no other routes match with request -func (gb *gearbox) Fallback(handlers ...handlerFunc) error { - return gb.registerFallback(handlers) + return gb.registerRoute(MethodTrace, path, handlers) } -// Use attaches a global middleware to the gearbox object. -// included in the handlers chain for all matched requests. -// it will always be executed before the handler and/or middlewares for the matched request -// For example, this is the right place for a logger or some security check or permission checking. -func (gb *gearbox) Use(middlewares ...handlerFunc) { - gb.handlers = append(gb.handlers, middlewares...) -} - -// Group appends a prefix to registered routes. +// Group appends a prefix to registered routes func (gb *gearbox) Group(prefix string, routes []*Route) []*Route { for _, route := range routes { route.Path = prefix + route.Path @@ -416,25 +444,85 @@ func (gb *gearbox) Group(prefix string, routes []*Route) []*Route { return routes } -// Handles all incoming requests and route them to proper handler according to -// method and path -func (gb *gearbox) handler(ctx *fasthttp.RequestCtx) { - if handlers, params := gb.matchRoute( - GetString(ctx.Request.Header.Method()), - GetString(ctx.URI().Path())); handlers != nil { - context := Context{ - RequestCtx: ctx, - Params: params, - handlers: append(gb.handlers, handlers...), - index: 0, +func (gb *gearbox) Static(prefix, root string) { + if gb.settings.CaseInSensitive { + prefix = strings.ToLower(prefix) + } + + // remove trailing slash + if len(root) > 1 && root[len(root)-1] == '/' { + root = root[:len(root)-1] + } + + if len(prefix) > 1 && prefix[len(prefix)-1] == '/' { + prefix = prefix[:len(prefix)-1] + } + + fs := &fasthttp.FS{ + Root: root, + IndexNames: []string{"index.html"}, + PathRewrite: func(ctx *fasthttp.RequestCtx) []byte { + path := ctx.Path() + + if len(path) >= len(prefix) { + path = path[len(prefix):] + } + + if len(path) > 0 && path[0] != '/' { + path = append([]byte("/"), path...) + } else if len(path) == 0 { + path = []byte("/") + } + return path + }, + } + + fileHandler := fs.NewRequestHandler() + handler := func(ctx Context) { + fctx := ctx.Context() + + fileHandler(fctx) + + status := fctx.Response.StatusCode() + if status != StatusNotFound && status != StatusForbidden { + return } - context.handlers[0](&context) - return + + // Pass to custom not found handlers if there are + if gb.router.notFound != nil { + gb.router.notFound[0](ctx) + return + } + + // Default Not Found response + fctx.Error(fasthttp.StatusMessage(fasthttp.StatusNotFound), + fasthttp.StatusNotFound) } - ctx.SetStatusCode(StatusNotFound) + // TODO: Improve + gb.Get(prefix, handler) + + if len(prefix) > 1 && prefix[len(prefix)-1] != '*' { + gb.Get(prefix+"/*", handler) + } +} + +// NotFound registers an http handlers that will be called when no other routes +// match with request +func (gb *gearbox) NotFound(handlers ...handlerFunc) { + gb.router.SetNotFound(handlers) +} + +// Use attaches a global middleware to the gearbox object. +// included in the handlers chain for all matched requests. +// it will always be executed before the handler and/or middlewares for the matched request +// For example, this is the right place for a logger or some security check or permission checking. +func (gb *gearbox) Use(middlewares ...handlerFunc) { + gb.middlewares = append(gb.middlewares, middlewares...) } +// printStartupMessage prints gearbox info log message in parent process +// and prints process id for child process func printStartupMessage(addr string) { if prefork.IsChild() { log.Printf("Started child proc #%v\n", os.Getpid()) diff --git a/gearbox_test.go b/gearbox_test.go index 797ac77..2fa4ed2 100644 --- a/gearbox_test.go +++ b/gearbox_test.go @@ -8,6 +8,9 @@ import ( "net" "net/http" "net/http/httputil" + "strconv" + "strings" + "sync" "testing" "time" @@ -37,17 +40,48 @@ func (c *fakeConn) Write(b []byte) (int, error) { return c.w.Write(b) } +// setupGearbox returns instace of gearbox struct +func setupGearbox(settings ...*Settings) *gearbox { + gb := new(gearbox) + gb.registeredRoutes = make([]*Route, 0) + + if len(settings) > 0 { + gb.settings = settings[0] + } else { + gb.settings = &Settings{} + } + + gb.router = &router{ + settings: gb.settings, + cache: make(map[string]*matchResult), + pool: sync.Pool{ + New: func() interface{} { + return new(context) + }, + }, + } + + return gb +} + // startGearbox constructs routing tree and creates server func startGearbox(gb *gearbox) { - gb.cache = NewCache(defaultCacheSize) - gb.constructRoutingTree() + gb.setupRouter() gb.httpServer = &fasthttp.Server{ - Handler: gb.handler, - Logger: nil, + Handler: gb.router.Handler, + Logger: &customLogger{}, LogAllErrors: false, } } +// emptyHandler just an empty handler +var emptyHandler = func(ctx Context) {} + +// empty Handlers chain is just an empty array +var emptyHandlersChain = handlersChain{} + +var fakeHandlersChain = handlersChain{emptyHandler} + // makeRequest makes an http request to http server and returns response or error func makeRequest(request *http.Request, gb *gearbox) (*http.Response, error) { // Dump request to send it @@ -78,36 +112,56 @@ func makeRequest(request *http.Request, gb *gearbox) (*http.Response, error) { if err != nil { return nil, err } - return resp, nil } // handler just an empty handler -var handler = func(ctx *Context) {} +var handler = func(ctx Context) {} + +// errorHandler contains buggy code +var errorHandler = func(ctx Context) { + m := make(map[string]int) + m["a"] = 0 + ctx.SendString(string(5 / m["a"])) +} + +// headerHandler echos header's value of key "my-header" +var headerHandler = func(ctx Context) { + ctx.Set("custom", ctx.Get("my-header")) +} + +// queryHandler answers with query's value of key "name" +var queryHandler = func(ctx Context) { + ctx.SendString(ctx.Query("name")) +} + +// bodyHandler answers with request body +var bodyHandler = func(ctx Context) { + ctx.Context().Response.SetBodyString(ctx.Body()) +} // unAuthorizedHandler sets status unauthorized in response -var unAuthorizedHandler = func(ctx *Context) { - ctx.RequestCtx.SetStatusCode(StatusUnauthorized) +var unAuthorizedHandler = func(ctx Context) { + ctx.Status(StatusUnauthorized) } // pingHandler returns string pong in response body -var pingHandler = func(ctx *Context) { - ctx.RequestCtx.Response.SetBodyString("pong") +var pingHandler = func(ctx Context) { + ctx.SendString("pong") } // fallbackHandler returns not found status with custom fallback handler in response body -var fallbackHandler = func(ctx *Context) { - ctx.RequestCtx.SetStatusCode(StatusNotFound) - ctx.RequestCtx.Response.SetBodyString("custom fallback handler") +var fallbackHandler = func(ctx Context) { + ctx.Status(StatusNotFound).SendString("custom fallback handler") } // emptyMiddleware does not stop the request and passes it to the next middleware/handler -var emptyMiddleware = func(ctx *Context) { +var emptyMiddleware = func(ctx Context) { ctx.Next() } // registerRoute matches with register route request with available methods and calls it -func registerRoute(gb Gearbox, method, path string, handler func(ctx *Context)) { +func registerRoute(gb Gearbox, method, path string, handler func(ctx Context)) { switch method { case MethodGet: gb.Get(path, handler) @@ -127,8 +181,6 @@ func registerRoute(gb Gearbox, method, path string, handler func(ctx *Context)) gb.Options(path, handler) case MethodTrace: gb.Trace(path, handler) - default: - gb.Method(method, path, handler) } } @@ -139,9 +191,13 @@ func TestMethods(t *testing.T) { routes := []struct { method string path string - handler func(ctx *Context) + handler func(ctx Context) }{ + {method: MethodGet, path: "/order/get", handler: queryHandler}, + {method: MethodPost, path: "/order/add", handler: bodyHandler}, + {method: MethodGet, path: "/books/find", handler: emptyHandler}, {method: MethodGet, path: "/articles/search", handler: emptyHandler}, + {method: MethodPut, path: "/articles/search", handler: emptyHandler}, {method: MethodHead, path: "/articles/test", handler: emptyHandler}, {method: MethodPost, path: "/articles/204", handler: emptyHandler}, {method: MethodPost, path: "/articles/205", handler: unAuthorizedHandler}, @@ -149,17 +205,17 @@ func TestMethods(t *testing.T) { {method: MethodPut, path: "/posts", handler: emptyHandler}, {method: MethodPatch, path: "/post/502", handler: emptyHandler}, {method: MethodDelete, path: "/post/a23011a", handler: emptyHandler}, - {method: MethodConnect, path: "/user/204", handler: emptyHandler}, - {method: MethodOptions, path: "/user/204/setting", handler: emptyHandler}, + {method: MethodConnect, path: "/user/204", handler: headerHandler}, + {method: MethodOptions, path: "/user/204/setting", handler: errorHandler}, {method: MethodTrace, path: "/users/*", handler: emptyHandler}, - {method: MethodTrace, path: "/users/test", handler: emptyHandler}, - {method: "CUSTOM", path: "/users/test/private", handler: emptyHandler}, } // get instance of gearbox gb := setupGearbox(&Settings{ - CaseSensitive: true, - Prefork: true, + CaseInSensitive: true, + AutoRecover: true, + HandleOPTIONS: true, + HandleMethodNotAllowed: true, }) // register routes according to method @@ -172,33 +228,44 @@ func TestMethods(t *testing.T) { // Requests that will be tested testCases := []struct { - method string - path string - statusCode int - body string + method string + path string + statusCode int + requestBody string + body string + headers map[string]string }{ + {method: MethodGet, path: "/order/get?name=art123", statusCode: StatusOK, body: "art123"}, + {method: MethodPost, path: "/order/add", requestBody: "testOrder", statusCode: StatusOK, body: "testOrder"}, + {method: MethodPost, path: "/books/find", statusCode: StatusMethodNotAllowed, body: "Method Not Allowed", headers: map[string]string{"Allow": "GET, OPTIONS"}}, {method: MethodGet, path: "/articles/search", statusCode: StatusOK}, {method: MethodGet, path: "/articles/search", statusCode: StatusOK}, {method: MethodGet, path: "/Articles/search", statusCode: StatusOK}, - {method: MethodGet, path: "/articles/searching", statusCode: StatusNotFound}, + {method: MethodOptions, path: "/articles/search", statusCode: StatusOK}, + {method: MethodOptions, path: "*", statusCode: StatusOK}, + {method: MethodOptions, path: "/*", statusCode: StatusOK}, + {method: MethodGet, path: "/articles/searching", statusCode: StatusNotFound, body: "Not Found"}, {method: MethodHead, path: "/articles/test", statusCode: StatusOK}, {method: MethodPost, path: "/articles/204", statusCode: StatusOK}, {method: MethodPost, path: "/articles/205", statusCode: StatusUnauthorized}, {method: MethodPost, path: "/Articles/205", statusCode: StatusUnauthorized}, - {method: MethodPost, path: "/articles/206", statusCode: StatusNotFound}, + {method: MethodPost, path: "/articles/206", statusCode: StatusNotFound, body: "Not Found"}, {method: MethodGet, path: "/ping", statusCode: StatusOK, body: "pong"}, {method: MethodPut, path: "/posts", statusCode: StatusOK}, {method: MethodPatch, path: "/post/502", statusCode: StatusOK}, {method: MethodDelete, path: "/post/a23011a", statusCode: StatusOK}, - {method: MethodConnect, path: "/user/204", statusCode: StatusOK}, - {method: MethodOptions, path: "/user/204/setting", statusCode: StatusOK}, + {method: MethodConnect, path: "/user/204", statusCode: StatusOK, headers: map[string]string{"custom": "testing"}}, + {method: MethodOptions, path: "/user/204/setting", statusCode: StatusInternalServerError, body: "Internal Server Error"}, {method: MethodTrace, path: "/users/testing", statusCode: StatusOK}, - {method: "CUSTOM", path: "/users/test/private", statusCode: StatusOK}, } for _, tc := range testCases { // create and make http request - req, _ := http.NewRequest(tc.method, tc.path, nil) + req, _ := http.NewRequest(tc.method, tc.path, strings.NewReader(tc.requestBody)) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(tc.requestBody))) + req.Header.Set("my-header", "testing") + response, err := makeRequest(req, gb) if err != nil { @@ -220,13 +287,73 @@ func TestMethods(t *testing.T) { if string(body) != tc.body { t.Fatalf("%s(%s): returned %s expected %s", tc.method, tc.path, body, tc.body) } + + for expectedKey, expectedValue := range tc.headers { + actualValue := response.Header.Get(expectedKey) + if actualValue != expectedValue { + t.Errorf(" mismatch for route '%s' parameter '%s' actual '%s', expected '%s'", + tc.path, expectedKey, actualValue, expectedValue) + } + } } } -// TestStart tests start service method -func TestStart(t *testing.T) { +func TestStatic(t *testing.T) { + // get instance of gearbox + gb := setupGearbox(&Settings{ + CaseInSensitive: true, + AutoRecover: true, + HandleOPTIONS: true, + HandleMethodNotAllowed: true, + }) + + gb.Static("/static/", "./assets/") + + // start serving + startGearbox(gb) + + // Requests that will be tested + testCases := []struct { + method string + path string + statusCode int + body string + }{ + {method: MethodGet, path: "/static/gearbox.png", statusCode: StatusOK}, + } + + for _, tc := range testCases { + // create and make http request + req, _ := http.NewRequest(tc.method, tc.path, nil) + + response, err := makeRequest(req, gb) + + if err != nil { + t.Fatalf("%s(%s): %s", tc.method, tc.path, err.Error()) + } + + // check status code + if response.StatusCode != tc.statusCode { + t.Fatalf("%s(%s): returned %d expected %d", tc.method, tc.path, response.StatusCode, tc.statusCode) + } + + // read body from response + body, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Fatalf("%s(%s): %s", tc.method, tc.path, err.Error()) + } + + // check response body + if tc.body != "" && string(body) != tc.body { + t.Fatalf("%s(%s): returned %s expected %s", tc.method, tc.path, body, tc.body) + } + } +} + +// TestStartWithPrefork tests start service method +func TestStartWithPrefork(t *testing.T) { gb := New(&Settings{ - DisableStartupMessage: true, + Prefork: true, }) go func() { @@ -238,23 +365,35 @@ func TestStart(t *testing.T) { } // TestStart tests start service method +func TestStart(t *testing.T) { + gb := New() + + go func() { + time.Sleep(1000 * time.Millisecond) + gb.Stop() + }() + + gb.Start(":3010") +} + +// TestStartWithTLS tests start service method func TestStartWithTLS(t *testing.T) { gb := New(&Settings{ - DisableStartupMessage: true, - TLSKeyPath: "ssl-cert-snakeoil.key", - TLSCertPath: "ssl-cert-snakeoil.crt", - TLSEnabled: true, + TLSKeyPath: "ssl-cert-snakeoil.key", + TLSCertPath: "ssl-cert-snakeoil.crt", + TLSEnabled: true, }) + // use a channel to hand off the error ( if any ) errs := make(chan error, 1) go func() { - time.Sleep(1000 * time.Millisecond) - _, err := tls.DialWithDialer(&net.Dialer{ - Timeout: time.Second * 10, - }, + _, err := tls.DialWithDialer( + &net.Dialer{ + Timeout: time.Second * 3, + }, "tcp", - "localhost:3000", + "localhost:3050", &tls.Config{ InsecureSkipVerify: true, }) @@ -262,7 +401,7 @@ func TestStartWithTLS(t *testing.T) { gb.Stop() }() - gb.Start(":3000") + gb.Start(":3050") // wait for an error err := <-errs @@ -285,23 +424,6 @@ func TestStartInvalidListener(t *testing.T) { } } -// TestStartConflictHandlers tests start with two handlers for the same path and method -func TestStartConflictHandlers(t *testing.T) { - gb := New() - - gb.Get("/test", handler) - gb.Get("/test", handler) - - go func() { - time.Sleep(1000 * time.Millisecond) - gb.Stop() - }() - - if err := gb.Start(":3001"); err == nil { - t.Fatalf("invalid listener passed") - } -} - // TestStop tests stop service method func TestStop(t *testing.T) { gb := New() @@ -315,17 +437,15 @@ func TestStop(t *testing.T) { } // TestRegisterFallback tests router fallback handler -func TestRegisterFallback(t *testing.T) { +func TestNotFound(t *testing.T) { // get instance of gearbox - gb := new(gearbox) - gb.registeredRoutes = make([]*Route, 0) - gb.settings = &Settings{} + gb := setupGearbox() // register valid route gb.Get("/ping", pingHandler) - // register our fallback - gb.Fallback(fallbackHandler) + // register not found handlers + gb.NotFound(fallbackHandler) // start serving startGearbox(gb) @@ -368,12 +488,64 @@ func TestRegisterFallback(t *testing.T) { } } -// Test Use function to try to register middlewares that work before all routes -func Test_Use(t *testing.T) { +// TestGroupRouting tests that you can do group routing +func TestGroupRouting(t *testing.T) { + // create gearbox instance + gb := setupGearbox() + routes := []*Route{ + gb.Get("/id", emptyHandler), + gb.Post("/abc", emptyHandler), + gb.Post("/abcd", emptyHandler), + } + gb.Group("/account", gb.Group("/api", routes)) + + // start serving + startGearbox(gb) + + // One valid request, one invalid + testCases := []struct { + method string + path string + statusCode int + body string + }{ + {method: MethodGet, path: "/account/api/id", statusCode: StatusOK}, + {method: MethodPost, path: "/account/api/abc", statusCode: StatusOK}, + {method: MethodPost, path: "/account/api/abcd", statusCode: StatusOK}, + {method: MethodGet, path: "/id", statusCode: StatusNotFound, body: "Not Found"}, + } + + for _, tc := range testCases { + // create and make http request + req, _ := http.NewRequest(tc.method, tc.path, nil) + response, err := makeRequest(req, gb) + + if err != nil { + t.Fatalf("%s(%s): %s", tc.method, tc.path, err.Error()) + } + + // check status code + if response.StatusCode != tc.statusCode { + t.Fatalf("%s(%s): returned %d expected %d", tc.method, tc.path, response.StatusCode, tc.statusCode) + } + + // read body from response + body, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Fatalf("%s(%s): %s", tc.method, tc.path, err.Error()) + } + + // check response body + if string(body) != tc.body { + t.Fatalf("%s(%s): returned %s expected %s", tc.method, tc.path, body, tc.body) + } + } +} + +// TestUse tries to register middlewares that work before all routes +func TestUse(t *testing.T) { // get instance of gearbox - gb := new(gearbox) - gb.registeredRoutes = make([]*Route, 0) - gb.settings = &Settings{} + gb := setupGearbox() // register valid route gb.Get("/ping", pingHandler) diff --git a/go.sum b/go.sum index 17894af..c880dea 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,17 @@ github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4= github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= -github.com/klauspost/compress v1.10.4 h1:jFzIFaf586tquEB5EhzQG0HwGNSlgAJpG53G6Ss11wc= -github.com/klauspost/compress v1.10.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg= github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.14.0 h1:67bfuW9azCMwW/Jlq/C+VeihNpAuJMWkYPBig1gdi3A= -github.com/valyala/fasthttp v1.14.0/go.mod h1:ol1PCaL0dX20wC0htZ7sYCsvCYmrouYra0zHzaclZhE= -github.com/valyala/fasthttp v1.15.0 h1:U2WxMim6rzae8NgjlDXNeH84c7XBBTeUcq+YXh5k+ok= -github.com/valyala/fasthttp v1.15.0/go.mod h1:YOKImeEosDdBPnxc0gy7INqi3m1zK6A+xl6TwOBhHCA= github.com/valyala/fasthttp v1.15.1 h1:eRb5jzWhbCn/cGu3gNJMcOfPUfXgXCcQIOHjh9ajAS8= github.com/valyala/fasthttp v1.15.1/go.mod h1:YOKImeEosDdBPnxc0gy7INqi3m1zK6A+xl6TwOBhHCA= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a h1:0R4NLDRDZX6JcmhJgXi5E4b8Wg84ihbmUKp/GvSPEzc= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 h1:OjiUf46hAmXblsZdnoSXsEUSKU8r1UEzcL5RVZ4gO9Y= golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/router.go b/router.go index 7a89fc4..84dedd0 100644 --- a/router.go +++ b/router.go @@ -1,454 +1,212 @@ package gearbox import ( - "fmt" - "regexp" - "sort" + "log" "strings" "sync" -) - -type routeNode struct { - Name string - Endpoints map[string][]*endpoint - Children map[string]*routeNode -} - -type paramType uint8 -// Supported parameter types -const ( - ptNoParam paramType = iota // No parameter (most strict) - ptRegexp // Regex parameter - ptParam // Normal parameter - ptMatchAll // Match all parameter (least strict) + "github.com/valyala/fasthttp" ) -type param struct { - Name string - Value string - Type paramType - IsOptional bool -} - -type endpoint struct { - Params []*param - Handlers handlersChain -} - -type routerFallback struct { - Handlers handlersChain -} - -type matchParamsResult struct { - Matched bool - Handlers handlersChain - Params map[string]string -} - -// validateRoutePath makes sure that path complies with path's rules -func validateRoutePath(path string) error { - // Check length of the path - length := len(path) - if length == 0 { - return fmt.Errorf("length is zero") - } - - if length == defaultMaxRequestURLLength { - return fmt.Errorf("length request url exceed the max limit") - } - - // Make sure path starts with / - if path[0] != '/' { - return fmt.Errorf("path must start with /") - } - - params := make(map[string]bool) - parts := strings.Split(trimPath(path), "/") - partsLen := len(parts) - for i := 0; i < partsLen; i++ { - if parts[i] == "" { - continue - } - if p := parseParameter(parts[i]); p != nil { - if p.Type == ptMatchAll && i != partsLen-1 { - return fmt.Errorf("* must be in the end of path") - } else if p.IsOptional && i != partsLen-1 { - return fmt.Errorf("only last parameter can be optional") - } else if p.Type == ptParam || p.Type == ptRegexp { - if _, ok := params[p.Name]; ok { - return fmt.Errorf("parameter is duplicated") - } - params[p.Name] = true - } - } - } - - return nil -} - -// registerRoute registers handler with method and path -func (gb *gearbox) registerRoute(method, path string, handlers handlersChain) *Route { - - if gb.settings.CaseSensitive { - path = strings.ToLower(path) - } - - route := &Route{ - Path: path, - Method: method, - Handlers: handlers, - } +var ( + defaultContentType = []byte("text/plain; charset=utf-8") +) - // Add route to registered routes - gb.registeredRoutes = append(gb.registeredRoutes, route) - return route +type router struct { + trees map[string]*node + cache map[string]*matchResult + cacheLen int + mutex sync.RWMutex + notFound handlersChain + settings *Settings + pool sync.Pool } -// registerFallback registers a single handler that will match only if all other routes fail to match -func (gb *gearbox) registerFallback(handlers handlersChain) error { - // Handler is not provided - if handlers == nil { - return fmt.Errorf("fallback does not contain a handler") - } - - gb.registeredFallback = &routerFallback{Handlers: handlers} - return nil +type matchResult struct { + handlers handlersChain + params map[string]string } -// createEmptyRouteNode creates a new route node with name -func createEmptyRouteNode(name string) *routeNode { - return &routeNode{ - Name: name, - Children: make(map[string]*routeNode), - Endpoints: make(map[string][]*endpoint), - } -} - -// parseParameter parses path part into param struct, or returns nil if it's -// not a parameter -func parseParameter(pathPart string) *param { - pathPartLen := len(pathPart) - if pathPartLen == 0 { - return nil - } - - // match all - if pathPart[0] == '*' { - return ¶m{ - Name: "*", - Type: ptMatchAll, - } - } - - isOptional := pathPart[pathPartLen-1] == '?' - if isOptional { - pathPart = pathPart[0 : pathPartLen-1] - } +// acquireCtx returns instance of context after initializing it +func (r *router) acquireCtx(fctx *fasthttp.RequestCtx) *context { + ctx := r.pool.Get().(*context) - params := strings.Split(pathPart, ":") - paramsLen := len(params) + // Initialize + ctx.index = 0 + ctx.paramValues = make(map[string]string) + ctx.requestCtx = fctx - if paramsLen == 2 && params[0] == "" { // Just a parameter - return ¶m{ - Name: params[1], - Type: ptParam, - IsOptional: isOptional, - } - } else if paramsLen == 3 && params[0] == "" { // Regex parameter - return ¶m{ - Name: params[1], - Value: params[2], - Type: ptRegexp, - IsOptional: isOptional, - } - } - - return nil + return ctx } -// getLeastStrictParamType returns least strict parameter type from list of -// parameters -func getLeastStrictParamType(params []*param) paramType { - pLen := len(params) - if pLen == 0 { - return ptNoParam - } - - pType := params[0].Type - for i := 1; i < pLen; i++ { - if params[i].Type > pType { - pType = params[i].Type - } - } - return pType +// releaseCtx frees context +func (r *router) releaseCtx(ctx *context) { + ctx.handlers = nil + ctx.paramValues = nil + ctx.requestCtx = nil + r.pool.Put(ctx) } -func isValidEndpoint(endpoints []*endpoint, newEndpoint *endpoint) bool { - for i := range endpoints { - if len(endpoints[i].Params) == len(newEndpoint.Params) { - isValid := false - for j := range endpoints[i].Params { - if endpoints[i].Params[j].Type != newEndpoint.Params[j].Type { - isValid = true - } - } - return isValid - } +// handle registers handlers for provided method and path to be used +// in routing incoming requests +func (r *router) handle(method, path string, handlers handlersChain) { + if path == "" { + panic("path is empty") + } else if method == "" { + panic("method is empty") + } else if path[0] != '/' { + panic("path must begin with '/' in path '" + path + "'") + } else if len(handlers) == 0 { + panic("no handlers provided with path '" + path + "'") } - return true -} - -// trimPath trims left and right slashes in path -func trimPath(path string) string { - pathLastIndex := len(path) - 1 - for path[pathLastIndex] == '/' && pathLastIndex > 0 { - pathLastIndex-- + // initialize tree if it's empty + if r.trees == nil { + r.trees = make(map[string]*node) } - pathFirstIndex := 1 - if path[0] != '/' { - pathFirstIndex = 0 + // get root of method if it's existing, otherwise creates it + root := r.trees[method] + if root == nil { + root = createRootNode() + r.trees[method] = root } - return path[pathFirstIndex : pathLastIndex+1] + root.addRoute(path, handlers) } -// constructRoutingTree constructs routing tree according to provided routes -func (gb *gearbox) constructRoutingTree() error { - // Firstly, create root node - gb.routingTreeRoot = createEmptyRouteNode("root") - - for _, route := range gb.registeredRoutes { - currentNode := gb.routingTreeRoot +// allowed checks if provided path can be routed in another method(s) +func (r *router) allowed(reqMethod, path string, ctx *context) string { + var allow string - // Handler is not provided - if route.Handlers == nil { - return fmt.Errorf("route %s with method %s does not contain any handlers", route.Path, route.Method) - } - - // Check if path is valid or not - if err := validateRoutePath(route.Path); err != nil { - return fmt.Errorf("route %s is not valid! - %s", route.Path, err.Error()) - } - - params := make([]*param, 0) - - // Split path into slices of parts - parts := strings.Split(route.Path, "/") - - partsLen := len(parts) - for i := 1; i < partsLen; i++ { - part := parts[i] - - // Do not create node if part is empty - if part == "" { + pathLen := len(path) + if (pathLen == 1 && path[0] == '*') || (pathLen > 1 && path[1] == '*') { + for method := range r.trees { + if method == MethodOptions { continue } - // Parse part as a parameter if it is - if param := parseParameter(part); param != nil { - params = append(params, param) - continue + if allow != "" { + allow += ", " + method + } else { + allow = method } - - // Try to get a child of current node with part, otherwise - //creates a new node and make it current node - partNode, ok := currentNode.Children[part] - if !ok { - partNode = createEmptyRouteNode(part) - currentNode.Children[part] = partNode - } - currentNode = partNode - } - - currentEndpoint := &endpoint{ - Handlers: route.Handlers, - Params: params, } - - // Make sure that current node does not have a handler for route's method - var endpoints []*endpoint - if result, ok := currentNode.Endpoints[route.Method]; ok { - if ok := isValidEndpoint(result, currentEndpoint); !ok { - return fmt.Errorf("there already registered method %s for %s", route.Method, route.Path) - } - - endpoints = append(result, currentEndpoint) - sort.Slice(endpoints, func(i, j int) bool { - iLen := len(endpoints[i].Params) - jLen := len(endpoints[j].Params) - if iLen == jLen { - iParamType := getLeastStrictParamType(endpoints[i].Params) - jParamType := getLeastStrictParamType(endpoints[j].Params) - return iParamType < jParamType - } - - return iLen > jLen - }) - } else { - endpoints = []*endpoint{currentEndpoint} - } - - // Save handler to route's method for current node - currentNode.Endpoints[route.Method] = endpoints - } - return nil -} - -// matchRoute matches provided method and path with handler if it's existing -func (gb *gearbox) matchRoute(method, path string) (handlersChain, map[string]string) { - if handlers, params := gb.matchRouteAgainstRegistered(method, path); handlers != nil { - return handlers, params + return allow } - if gb.registeredFallback != nil && gb.registeredFallback.Handlers != nil { - return gb.registeredFallback.Handlers, make(map[string]string) - } - - return nil, nil -} - -func matchEndpointParams(ep *endpoint, paths []string, pathIndex int) (map[string]string, bool) { - endpointParams := ep.Params - endpointParamsLen := len(endpointParams) - pathsLen := len(paths) - paramDic := make(map[string]string, endpointParamsLen) - - paramIdx := 0 - for paramIdx < endpointParamsLen { - if endpointParams[paramIdx].Type == ptMatchAll { - // Last parameter, so we can return - return paramDic, true - } - - // path has ended and there is more parameters to match - if pathIndex >= pathsLen { - // If it's optional means this is the last parameter. - if endpointParams[paramIdx].IsOptional { - return paramDic, true - } - - return nil, false - } - - if paths[pathIndex] == "" { - pathIndex++ + for method, tree := range r.trees { + if method == reqMethod || method == MethodOptions { continue } - if endpointParams[paramIdx].Type == ptParam { - paramDic[endpointParams[paramIdx].Name] = paths[pathIndex] - } else if endpointParams[paramIdx].Type == ptRegexp { - if match, _ := regexp.MatchString(endpointParams[paramIdx].Value, paths[pathIndex]); match { - paramDic[endpointParams[paramIdx].Name] = paths[pathIndex] - } else if !endpointParams[paramIdx].IsOptional { - return nil, false + handlers := tree.matchRoute(path, ctx) + if handlers != nil { + if allow != "" { + allow += ", " + method + } else { + allow = method } } - - paramIdx++ - pathIndex++ } - for pathIndex < pathsLen && paths[pathIndex] == "" { - pathIndex++ + if len(allow) > 0 { + allow += ", " + MethodOptions } - - // There is more parts, so no match - if pathsLen-pathIndex > 0 { - return nil, false - } - - return paramDic, true + return allow } -func matchNodeEndpoints(node *routeNode, method string, paths []string, - pathIndex int, result *matchParamsResult, wg *sync.WaitGroup) { - if endpoints, ok := node.Endpoints[method]; ok { - for j := range endpoints { - if params, matched := matchEndpointParams(endpoints[j], paths, pathIndex); matched { - result.Matched = true - result.Params = params - result.Handlers = endpoints[j].Handlers - wg.Done() - return +// Handler handles all incoming requests +func (r *router) Handler(fctx *fasthttp.RequestCtx) { + context := r.acquireCtx(fctx) + defer r.releaseCtx(context) + + if r.settings.AutoRecover { + defer func(fctx *fasthttp.RequestCtx) { + if rcv := recover(); rcv != nil { + log.Printf("recovered from error: %v", rcv) + fctx.Error(fasthttp.StatusMessage(fasthttp.StatusInternalServerError), + fasthttp.StatusInternalServerError) } - } + }(fctx) } - result.Matched = false - wg.Done() -} - -func (gb *gearbox) matchRouteAgainstRegistered(method, path string) (handlersChain, map[string]string) { - // Start with root node - currentNode := gb.routingTreeRoot + path := GetString(fctx.URI().PathOriginal()) - // Return if root is empty, or path is not valid - if currentNode == nil || path == "" || path[0] != '/' || len(path) > defaultMaxRequestURLLength { - return nil, nil - } - - if gb.settings.CaseSensitive { + if r.settings.CaseInSensitive { path = strings.ToLower(path) } - trimmedPath := trimPath(path) + method := GetString(fctx.Method()) + + var cacheKey string + useCache := !r.settings.DisableCaching && + (method == MethodGet || method == MethodPost) + if useCache { + cacheKey = path + method + r.mutex.RLock() + cacheResult, ok := r.cache[cacheKey] - // Try to get from cache if it's enabled - cacheKey := "" - if !gb.settings.DisableCaching { - cacheKey = method + trimmedPath - if cacheResult, ok := gb.cache.Get(cacheKey).(*matchParamsResult); ok { - return cacheResult.Handlers, cacheResult.Params + if ok { + context.handlers = cacheResult.handlers + context.paramValues = cacheResult.params + r.mutex.RUnlock() + context.handlers[0](context) + return } + r.mutex.RUnlock() } - paths := strings.Split(trimmedPath, "/") + if root := r.trees[method]; root != nil { + if handlers := root.matchRoute(path, context); handlers != nil { + context.handlers = handlers + context.handlers[0](context) - var wg sync.WaitGroup - lastMatchedNodes := []*matchParamsResult{{}} - lastMatchedNodesIndex := 1 - wg.Add(1) - go matchNodeEndpoints(currentNode, method, paths, 0, lastMatchedNodes[0], &wg) + if useCache { + r.mutex.Lock() - for i := range paths { - if paths[i] == "" { - continue + if r.cacheLen == r.settings.CacheSize { + r.cache = make(map[string]*matchResult) + r.cacheLen = 0 + } + r.cache[cacheKey] = &matchResult{ + handlers: handlers, + params: context.paramValues, + } + r.cacheLen++ + r.mutex.Unlock() + } + return } + } - // Try to match part with a child of current node - pathNode, ok := currentNode.Children[paths[i]] - if !ok { - break + if method == MethodOptions && r.settings.HandleOPTIONS { + if allow := r.allowed(method, path, context); len(allow) > 0 { + fctx.Response.Header.Set("Allow", allow) + return + } + } else if r.settings.HandleMethodNotAllowed { + if allow := r.allowed(method, path, context); len(allow) > 0 { + fctx.Response.Header.Set("Allow", allow) + fctx.SetStatusCode(fasthttp.StatusMethodNotAllowed) + fctx.SetContentTypeBytes(defaultContentType) + fctx.SetBodyString(fasthttp.StatusMessage(fasthttp.StatusMethodNotAllowed)) + return } - - // set matched node as current node - currentNode = pathNode - - v := matchParamsResult{} - lastMatchedNodes = append(lastMatchedNodes, &v) - wg.Add(1) - go matchNodeEndpoints(currentNode, method, paths, i+1, &v, &wg) - lastMatchedNodesIndex++ } - wg.Wait() - - // Return longest prefix match - for i := lastMatchedNodesIndex - 1; i >= 0; i-- { - if lastMatchedNodes[i].Matched { - if !gb.settings.DisableCaching { - go func(key string, matchResult *matchParamsResult) { - gb.cache.Set(key, matchResult) - }(string(cacheKey), lastMatchedNodes[i]) - } - - return lastMatchedNodes[i].Handlers, lastMatchedNodes[i].Params - } + // Custom Not Found (404) handlers + if r.notFound != nil { + r.notFound[0](context) + return } - return nil, nil + // Default Not Found response + fctx.Error(fasthttp.StatusMessage(fasthttp.StatusNotFound), + fasthttp.StatusNotFound) +} + +// SetNotFound appends handlers to custom not found (404) handlers +func (r *router) SetNotFound(handlers handlersChain) { + r.notFound = append(r.notFound, handlers...) } diff --git a/router_test.go b/router_test.go index 66f8673..75037b6 100644 --- a/router_test.go +++ b/router_test.go @@ -1,493 +1,74 @@ package gearbox import ( - "fmt" + "sync" "testing" ) -// setupGearbox returns instace of gearbox struct -func setupGearbox(settings ...*Settings) *gearbox { - gb := new(gearbox) - gb.registeredRoutes = make([]*Route, 0) - - if len(settings) > 0 { - gb.settings = settings[0] - } else { - gb.settings = &Settings{} - } - - gb.cache = NewCache(defaultCacheSize) - return gb -} - -// TestValidateRoutePath tests if provided paths are valid or not -func TestValidateRoutePath(t *testing.T) { - // test cases - tests := []struct { - input string - isErr bool - }{ - {input: "", isErr: true}, - {input: "user", isErr: true}, - {input: "/user", isErr: false}, - {input: "/admin/", isErr: false}, - {input: "/user/*/get", isErr: true}, - {input: "/user/*", isErr: false}, - {input: "/user/:name", isErr: false}, - {input: "/user/:name/:name", isErr: true}, - {input: "/user/:name?/get", isErr: true}, - } - - for _, tt := range tests { - err := validateRoutePath(tt.input) - if (err != nil && !tt.isErr) || (err == nil && tt.isErr) { - errMsg := "" - - // get error message if there is - if err != nil { - errMsg = err.Error() - } - - t.Errorf("input %s find error %t %s expecting error %t", tt.input, err == nil, errMsg, tt.isErr) - } - } -} - -// TestCreateEmptyNode tests creating route node with specific name -func TestCreateEmptyNode(t *testing.T) { - name := "test_node" - node := createEmptyRouteNode(name) - - if node == nil { - // node.Name != name { - t.Errorf("find name %s expecting name %s", node.Name, name) - } -} - -// emptyHandler just an empty handler -var emptyHandler = func(ctx *Context) {} - -// empty Handlers chain is just an empty array -var emptyHandlersChain = handlersChain{} - -// TestRegisterRoute tests registering routes after validating it -func TestRegisterRoute(t *testing.T) { - // test cases - tests := []struct { - method string - path string - handler handlersChain - isErr bool - }{ - {method: MethodPut, path: "/admin/welcome", handler: emptyHandlersChain, isErr: false}, - {method: MethodPost, path: "/user/add", handler: emptyHandlersChain, isErr: false}, - {method: MethodGet, path: "/account/get", handler: emptyHandlersChain, isErr: false}, - {method: MethodGet, path: "/account/*", handler: emptyHandlersChain, isErr: false}, - {method: MethodGet, path: "/account/*", handler: emptyHandlersChain, isErr: false}, - {method: MethodDelete, path: "/account/delete", handler: emptyHandlersChain, isErr: false}, - {method: MethodDelete, path: "/account/delete", handler: nil, isErr: true}, - {method: MethodGet, path: "/account/*/getAccount", handler: nil, isErr: true}, - {method: MethodGet, path: "/books/:name/:test", handler: emptyHandlersChain, isErr: false}, - {method: MethodGet, path: "/books/:name/:name", handler: nil, isErr: true}, - } - - // counter for valid routes - validCounter := 0 - - for _, tt := range tests { - // create gearbox instance so old errors don't affect new routes - gb := setupGearbox() - - gb.registerRoute(tt.method, tt.path, tt.handler) - err := gb.constructRoutingTree() - if (err != nil && !tt.isErr) || (err == nil && tt.isErr) { - errMsg := "" - - // get error message if there is - if err != nil { - errMsg = err.Error() - } - - t.Errorf("input %v find error %t %s expecting error %t", tt, err == nil, errMsg, tt.isErr) - } - - if !tt.isErr { - validCounter++ - } - } -} - -// TestRegisterInvalidRoute tests registering invalid routes -func TestRegisterInvalidRoute(t *testing.T) { - // create gearbox instance - gb := setupGearbox() - - // test handler is nil - gb.registerRoute(MethodGet, "invalid Path", emptyHandlersChain) - - if err := gb.constructRoutingTree(); err == nil { - t.Errorf("input GET invalid Path find nil expecting error") - } -} - -// TestParseParameter tests parsing parameters into param struct -func TestParseParameter(t *testing.T) { - tests := []struct { - path string - output *param - }{ - {path: ":test", output: ¶m{Name: "test", Type: ptParam}}, - {path: ":test2:[a-z]", output: ¶m{Name: "test2", Value: "[a-z]", Type: ptRegexp}}, - {path: "*", output: ¶m{Name: "*", Type: ptMatchAll}}, - {path: "user:[a-z]", output: nil}, - {path: "user", output: nil}, - {path: "", output: nil}, - } - - for _, test := range tests { - p := parseParameter(test.path) - if test.output == nil && p == nil { - continue - } - if (test.output == nil && p != nil) || - (test.output != nil && p == nil) || - (test.output.Name != p.Name) || - (test.output.Type != p.Type) || - (test.output.Value != p.Value) { - t.Errorf("path %s, find %v expected %v", test.path, p, test.output) - } - } -} - -// TestGetLeastStrictParamType test -func TestGetLeastStrictParamType(t *testing.T) { - tests := []struct { - params []*param - output paramType - }{ - {params: []*param{}, output: ptNoParam}, - {params: []*param{ - {Type: ptParam, Name: "name"}, - {Type: ptRegexp, Name: "test", Value: "[a-z]"}, - {Type: ptMatchAll, Name: "*"}, - }, output: ptMatchAll}, - {params: []*param{ - {Type: ptParam, Name: "name"}, - {Type: ptMatchAll, Name: "*"}, - }, output: ptMatchAll}, - {params: []*param{ - {Type: ptParam, Name: "name"}, - {Type: ptRegexp, Name: "test3", Value: "[a-z]"}, - }, output: ptParam}, - {params: []*param{ - {Type: ptRegexp, Name: "test3", Value: "[a-z]"}, - {Type: ptMatchAll, Name: "*"}, - }, output: ptMatchAll}, - } - - for _, test := range tests { - paramType := getLeastStrictParamType(test.params) - if paramType != test.output { - t.Errorf("params %v, find %d expected %d", test.params, paramType, test.output) - } - } -} - -// TestTrimPath test -func TestTrimPath(t *testing.T) { - tests := []struct { - input string - output string - }{ - {input: "/", output: ""}, - {input: "/test/", output: "test"}, - {input: "test2/", output: "test2"}, - {input: "test2", output: "test2"}, - {input: "/user/test", output: "user/test"}, - {input: "/books/get/test/", output: "books/get/test"}, - } - - for _, test := range tests { - trimmedPath := trimPath(test.input) - if trimmedPath != test.output { - t.Errorf("path %s, find %s expected %s", test.input, trimmedPath, test.output) - } - } -} - -// TestIsValidEndpoint test -func TestIsValidEndpoint(t *testing.T) { - tests := []struct { - endpoints []*endpoint - newEndpoint *endpoint - output bool - }{ - {endpoints: []*endpoint{}, newEndpoint: &endpoint{}, output: true}, - {endpoints: []*endpoint{ - {Handlers: handlersChain{emptyHandler}, Params: []*param{ - {Name: "user", Type: ptParam}, - {Name: "name", Type: ptParam}, - }}, - {Handlers: handlersChain{emptyHandler}, Params: []*param{}}, - }, newEndpoint: &endpoint{Handlers: handlersChain{emptyHandler}, Params: []*param{ - {Name: "test", Type: ptParam}, - }}, output: true}, - {endpoints: []*endpoint{}, newEndpoint: &endpoint{}, output: true}, - {endpoints: []*endpoint{ - {Handlers: handlersChain{emptyHandler}, Params: []*param{ - {Name: "user", Type: ptParam}, - }}, - {Handlers: handlersChain{emptyHandler}, Params: []*param{}}, - }, newEndpoint: &endpoint{Handlers: handlersChain{emptyHandler}, Params: []*param{ - {Name: "test", Type: ptParam}, - }}, output: false}, - {endpoints: []*endpoint{ - {Handlers: handlersChain{emptyHandler}, Params: []*param{ - {Name: "user", Type: ptRegexp, Value: "[a-z]"}, - }}, - {Handlers: handlersChain{emptyHandler}, Params: []*param{}}, - }, newEndpoint: &endpoint{Handlers: handlersChain{emptyHandler}, Params: []*param{ - {Name: "test", Type: ptParam}, - }}, output: true}, - {endpoints: []*endpoint{ - {Handlers: handlersChain{emptyHandler}, Params: []*param{ - {Name: "*", Type: ptMatchAll}, - }}, - {Handlers: handlersChain{emptyHandler}, Params: []*param{}}, - }, newEndpoint: &endpoint{Handlers: handlersChain{emptyHandler}, Params: []*param{ - {Name: "test", Type: ptRegexp, Value: "[a-z]"}, - }}, output: true}, - } - - for _, test := range tests { - isValid := isValidEndpoint(test.endpoints, test.newEndpoint) - if isValid != test.output { - t.Errorf("endpoints %v, new endpoint %v find %t expected %t", test.endpoints, - test.newEndpoint, isValid, test.output) - } - } -} - -// TestConstructRoutingTree tests constructing routing tree and matching routes properly -func TestConstructRoutingTree(t *testing.T) { - // create gearbox instance - gb := setupGearbox(&Settings{ - CacheSize: 1, - }) - +func TestHandle(t *testing.T) { // testing routes routes := []struct { - method string - path string - handler handlersChain + method string + path string + conflict bool + handlers handlersChain }{ - {method: MethodGet, path: "/articles/search", handler: emptyHandlersChain}, - {method: MethodGet, path: "/articles/test", handler: emptyHandlersChain}, - {method: MethodGet, path: "/articles/204", handler: emptyHandlersChain}, - {method: MethodGet, path: "/posts", handler: emptyHandlersChain}, - {method: MethodGet, path: "/post/502", handler: emptyHandlersChain}, - {method: MethodGet, path: "/post/a23011a", handler: emptyHandlersChain}, - {method: MethodGet, path: "/user/204", handler: emptyHandlersChain}, - {method: MethodGet, path: "/user/205/", handler: emptyHandlersChain}, - {method: MethodPost, path: "/user/204/setting", handler: emptyHandlersChain}, - {method: MethodGet, path: "/users/*", handler: emptyHandlersChain}, - {method: MethodGet, path: "/books/get/:name", handler: emptyHandlersChain}, - {method: MethodGet, path: "/books/get/*", handler: emptyHandlersChain}, - {method: MethodGet, path: "/books/search/:pattern:([a-z]+", handler: emptyHandlersChain}, - {method: MethodGet, path: "/books/search/:pattern", handler: emptyHandlersChain}, - {method: MethodGet, path: "/books/search/:pattern1/:pattern2/:pattern3", handler: emptyHandlersChain}, - {method: MethodGet, path: "/books//search/*", handler: emptyHandlersChain}, - {method: MethodGet, path: "/account/:name?", handler: emptyHandlersChain}, - {method: MethodGet, path: "/profile/:name:([a-z]+)?", handler: emptyHandlersChain}, - {method: MethodGet, path: "/order/:name1/:name2:([a-z]+)?", handler: emptyHandlersChain}, - {method: MethodGet, path: "/", handler: emptyHandlersChain}, - } - - // register routes - for _, r := range routes { - gb.registerRoute(r.method, r.path, r.handler) - } - - gb.constructRoutingTree() - - // requests test cases - requests := []struct { - method string - path string - params map[string]string - match bool - }{ - {method: MethodPut, path: "/admin/welcome", match: false, params: make(map[string]string)}, - {method: MethodGet, path: "/articles/search", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/articles/test", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/articles/test", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/articles/test", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/articles/204", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/posts", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/post/502", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/post/a23011a", match: true, params: make(map[string]string)}, - {method: MethodPost, path: "/post/a23011a", match: false, params: make(map[string]string)}, - {method: MethodGet, path: "/user/204", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/user/205", match: true, params: make(map[string]string)}, - {method: MethodPost, path: "/user/204/setting", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/users/ahmed", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/users/ahmed/ahmed", match: true, params: make(map[string]string)}, - {method: MethodPut, path: "/users/ahmed/ahmed", match: false, params: make(map[string]string)}, - {method: MethodGet, path: "/books/get/test", match: true, params: map[string]string{"name": "test"}}, - {method: MethodGet, path: "/books/search/test", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/books/search//test", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/books/search/123", match: true, params: map[string]string{"pattern": "123"}}, - {method: MethodGet, path: "/books/search/test1/test2/test3", match: true, params: map[string]string{"pattern1": "test1", "pattern2": "test2", "pattern3": "test3"}}, - {method: MethodGet, path: "/books/search/test/test2", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/books/search/test/test2", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/account/testuser", match: true, params: map[string]string{"name": "testuser"}}, - {method: MethodGet, path: "/account", match: true, params: make(map[string]string)}, - {method: MethodPut, path: "/account/test1/test2", match: false, params: make(map[string]string)}, - {method: MethodGet, path: "/profile/testuser", match: true, params: map[string]string{"name": "testuser"}}, - {method: MethodGet, path: "/profile", match: true, params: make(map[string]string)}, - {method: MethodGet, path: "/order/test1", match: true, params: map[string]string{"name1": "test1"}}, - {method: MethodGet, path: "/order/test1/test2/", match: true, params: map[string]string{"name1": "test1", "name2": "test2"}}, - {method: MethodPut, path: "/order/test1/test2/test3", match: false, params: make(map[string]string)}, - {method: MethodGet, path: "/", match: true, params: make(map[string]string)}, - } - - // test matching routes - for _, rq := range requests { - handler, params := gb.matchRoute(rq.method, rq.path) - if (handler != nil && !rq.match) || (handler == nil && rq.match) { - t.Errorf("input %s %s find nil expecting handler", rq.method, rq.path) - } - for paramKey, expectedParam := range rq.params { - if actualParam, ok := params[paramKey]; !ok || actualParam != expectedParam { - if !ok { - actualParam = "nil" - } - for k, w := range params { - fmt.Println(k, string(w)) - } - - t.Errorf("input %s %s parameter %s find %s expecting %s", - rq.method, rq.path, paramKey, actualParam, expectedParam) + {method: MethodGet, path: "/articles/search", conflict: false, handlers: fakeHandlersChain}, + {method: MethodGet, path: "/articles/test", conflict: false, handlers: fakeHandlersChain}, + {method: MethodGet, path: "", conflict: true, handlers: fakeHandlersChain}, + {method: "", path: "/articles/test", conflict: true, handlers: fakeHandlersChain}, + {method: MethodGet, path: "orders/test", conflict: true, handlers: fakeHandlersChain}, + {method: MethodGet, path: "/books/test", conflict: true, handlers: emptyHandlersChain}, + } + + router := &router{ + settings: &Settings{}, + cache: make(map[string]*matchResult), + pool: sync.Pool{ + New: func() interface{} { + return new(context) + }, + }, + } + + for _, route := range routes { + recv := catchPanic(func() { + router.handle(route.method, route.path, route.handlers) + }) + + if route.conflict { + if recv == nil { + t.Errorf("no panic for conflicting route '%s'", route.path) } + } else if recv != nil { + t.Errorf("unexpected panic for route '%s': %v", route.path, recv) } } -} -// TestNullRoutingTree tests matching with null routing tree -func TestNullRoutingTree(t *testing.T) { - // create gearbox instance - gb := setupGearbox() - - // register route - gb.registerRoute(MethodGet, "/*", emptyHandlersChain) - - // test handler is nil - if handler, _ := gb.matchRoute(MethodGet, "/hello/world"); handler != nil { - t.Errorf("input GET /hello/world find handler expecting nil") - } } -// TestMatchAll tests matching all requests with one handler -func TestMatchAll(t *testing.T) { - // create gearbox instance - gb := setupGearbox() - - // register route - gb.registerRoute(MethodGet, "/*", emptyHandlersChain) - gb.constructRoutingTree() - - // test handler is not nil - if handler, _ := gb.matchRoute(MethodGet, "/hello/world"); handler == nil { - t.Errorf("input GET /hello/world find nil expecting handler") - } - - if handler, _ := gb.matchRoute(MethodGet, "//world"); handler == nil { - t.Errorf("input GET //world find nil expecting handler") - } -} - -// TestConstructRoutingTree tests constructing routing tree with two handlers -// for the same path and method -func TestConstructRoutingTreeConflict(t *testing.T) { - // create gearbox instance - gb := setupGearbox() - - // register routes - gb.registerRoute(MethodGet, "/articles/test", emptyHandlersChain) - gb.registerRoute(MethodGet, "/articles/test", emptyHandlersChain) - - if err := gb.constructRoutingTree(); err == nil { - t.Fatalf("invalid listener passed") - } -} - -// TestNoRegisteredFallback tests that if no registered fallback is available -// matchRoute() returns nil -func TestNoRegisteredFallback(t *testing.T) { - // create gearbox instance - gb := setupGearbox() - - // register routes - gb.registerRoute(MethodGet, "/articles", emptyHandlersChain) - gb.constructRoutingTree() - - // attempt to match route that cannot match - if handler, _ := gb.matchRoute(MethodGet, "/fail"); handler != nil { - t.Errorf("input GET /fail found a valid handler, expecting nil") - } -} - -// TestFallback tests that if a registered fallback is available -// matchRoute() returns the non-nil registered fallback handler -func TestFallback(t *testing.T) { - // create gearbox instance - gb := setupGearbox() - - // register routes - gb.registerRoute(MethodGet, "/articles", emptyHandlersChain) - if err := gb.registerFallback(emptyHandlersChain); err != nil { - t.Errorf("invalid fallback: %s", err.Error()) - } - gb.constructRoutingTree() - - // attempt to match route that cannot match - if handler, _ := gb.matchRoute(MethodGet, "/fail"); handler == nil { - t.Errorf("input GET /fail did not find a valid handler, expecting valid fallback handler") +func TestHandler(t *testing.T) { + routes := []struct { + method string + path string + handlers handlersChain + }{ + {method: MethodGet, path: "/articles/search", handlers: fakeHandlersChain}, + {method: MethodGet, path: "/articles/test", handlers: fakeHandlersChain}, } -} - -// TestInvalidFallback tests that a fallback cannot be registered -// with a nil handler -func TestInvalidFallback(t *testing.T) { - // create gearbox instance - gb := setupGearbox() - // attempt to register an invalid (nil) fallback handler - if err := gb.registerFallback(nil); err == nil { - t.Errorf("registering an invalid fallback did not return an error, expected error") + router := &router{ + settings: &Settings{}, + cache: make(map[string]*matchResult), + pool: sync.Pool{ + New: func() interface{} { + return new(context) + }, + }, } -} -// TestGroupRouting tests that you can do group routing -func TestGroupRouting(t *testing.T) { - // create gearbox instance - gb := setupGearbox() - routes := []*Route{gb.Get("/id", emptyHandler), gb.Post("/abc", emptyHandler), gb.Post("/abcd", emptyHandler)} - gb.Group("/account", routes) - // attempt to register an invalid (nil) fallback handler - if err := gb.constructRoutingTree(); err != nil { - t.Errorf("Grout routing failed, error: %v", err) + for _, route := range routes { + router.handle(route.method, route.path, route.handlers) } -} -// TestNestedGroupRouting tests that you can do group routing inside a group routing -func TestNestedGroupRouting(t *testing.T) { - // create gearbox instance - gb := setupGearbox() - routes := []*Route{gb.Get("/id", emptyHandler), gb.Post("/abc", emptyHandler), gb.Post("/abcd", emptyHandler)} - gb.Group("/account", gb.Group("/api", routes)) - // attempt to register an invalid (nil) fallback handler - if err := gb.constructRoutingTree(); err != nil { - t.Errorf("Grout routing failed, error: %v", err) - } } diff --git a/tree.go b/tree.go new file mode 100644 index 0000000..bf7fba7 --- /dev/null +++ b/tree.go @@ -0,0 +1,168 @@ +package gearbox + +import ( + "strings" +) + +type nodeType uint8 + +const ( + static nodeType = iota + root + parama + catchAll +) + +type node struct { + path string + param *node + children map[string]*node + nType nodeType + handlers handlersChain +} + +func (n *node) addRoute(path string, handlers handlersChain) { + currentNode := n + originalPath := path + path = path[1:] + + paramNames := make(map[string]bool) + + for { + pathLen := len(path) + if pathLen == 0 { + if currentNode.handlers != nil { + panic("handlers are already registered for path '" + originalPath + "'") + } + currentNode.handlers = handlers + break + } + + segmentDelimiter := strings.Index(path, "/") + if segmentDelimiter == -1 { + segmentDelimiter = pathLen + } + + pathSegment := path[:segmentDelimiter] + if pathSegment[0] == ':' || pathSegment[0] == '*' { + // Parameter + if len(currentNode.children) > 0 { + panic("parameter " + pathSegment + + " conflicts with existing static children in path '" + + originalPath + "'") + } + + if currentNode.param != nil { + if currentNode.param.path[0] == '*' { + panic("parameter " + pathSegment + + " conflicts with catch all (*) route in path '" + + originalPath + "'") + } else if currentNode.param.path != pathSegment { + panic("parameter " + pathSegment + " in new path '" + + originalPath + "' conflicts with existing wildcard '" + + currentNode.param.path) + } + } + + if currentNode.param == nil { + var nType nodeType + if pathSegment[0] == '*' { + nType = catchAll + if pathLen > 1 { + panic("catch all (*) routes are only allowed " + + "at the end of the path in path '" + + originalPath + "'") + } + } else { + nType = parama + if _, ok := paramNames[pathSegment]; ok { + panic("parameter " + pathSegment + + " must be unique in path '" + originalPath + "'") + } else { + paramNames[pathSegment] = true + } + } + + currentNode.param = &node{ + path: pathSegment, + nType: nType, + children: make(map[string]*node), + } + } + currentNode = currentNode.param + } else { + // Static + if currentNode.param != nil { + panic(pathSegment + "' conflicts with existing parameter " + + currentNode.param.path + " in path '" + originalPath + "'") + } + if child, ok := currentNode.children[pathSegment]; ok { + currentNode = child + + } else { + child = &node{ + path: pathSegment, + nType: static, + children: make(map[string]*node), + } + currentNode.children[pathSegment] = child + currentNode = child + } + } + + if pathLen > segmentDelimiter { + segmentDelimiter++ + } + path = path[segmentDelimiter:] + } +} + +func (n *node) matchRoute(path string, ctx *context) handlersChain { + pathLen := len(path) + if pathLen > 0 && path[0] != '/' { + return nil + } + + currentNode := n + path = path[1:] + + for { + pathLen = len(path) + + if pathLen == 0 || currentNode.nType == catchAll { + return currentNode.handlers + } + segmentDelimiter := strings.Index(path, "/") + if segmentDelimiter == -1 { + segmentDelimiter = pathLen + } + pathSegment := path[:segmentDelimiter] + + if pathLen > segmentDelimiter { + segmentDelimiter++ + } + path = path[segmentDelimiter:] + + if currentNode.param != nil { + currentNode = currentNode.param + ctx.paramValues[currentNode.path[1:]] = pathSegment + continue + } + + if child, ok := currentNode.children[pathSegment]; ok { + currentNode = child + continue + } + + return nil + } +} + +// createRootNode creates an instance of node with root type +func createRootNode() *node { + return &node{ + nType: root, + path: "/", + children: make(map[string]*node), + } +} diff --git a/tree_test.go b/tree_test.go new file mode 100644 index 0000000..5aa6a49 --- /dev/null +++ b/tree_test.go @@ -0,0 +1,133 @@ +package gearbox + +import "testing" + +func catchPanic(f func()) (recv interface{}) { + defer func() { + recv = recover() + }() + + f() + return +} + +type testRoute struct { + path string + conflict bool +} + +func TestAddRoute(t *testing.T) { + tree := createRootNode() + + routes := []testRoute{ + {"/cmd/:tool/:sub", false}, + {"/cmd/vet", true}, + {"/src/*", false}, + {"/src/*", true}, + {"/src/test", true}, + {"/src/:test", true}, + {"/src/", false}, + {"/src1/", false}, + {"/src1/*", false}, + {"/search/:query", false}, + {"/search/invalid", true}, + {"/user_:name", false}, + {"/user_x", false}, + {"/id:id", false}, + {"/id/:id", false}, + {"/id/:value", true}, + {"/id/:id/settings", false}, + {"/id/:id/:type", true}, + {"/*", true}, + {"books/*/get", true}, + {"/file/test", false}, + {"/file/test", true}, + {"/file/:test", true}, + {"/orders/:id/settings/:id", true}, + {"/accounts/*/settings", true}, + {"/results/*", false}, + {"/results/*/view", true}, + } + for _, route := range routes { + recv := catchPanic(func() { + tree.addRoute(route.path, emptyHandlersChain) + }) + + if route.conflict { + if recv == nil { + t.Errorf("no panic for conflicting route '%s'", route.path) + } + } else if recv != nil { + t.Errorf("unexpected panic for route '%s': %v", route.path, recv) + } + } +} + +type testRequests []struct { + path string + match bool + params map[string]string +} + +func TestMatchRoute(t *testing.T) { + tree := createRootNode() + + routes := [...]string{ + "/hi", + "/contact", + "/users/:id/", + "/books/*", + "/search/:item1/settings/:item2", + "/co", + "/c", + "/a", + "/ab", + "/doc/", + "/doc/go_faq.html", + "/doc/go1.html", + "/α", + "/β", + } + for _, route := range routes { + tree.addRoute(route, emptyHandlersChain) + } + + requests := testRequests{ + {"/a", true, nil}, + {"/", false, nil}, + {"/hi", true, nil}, + {"/contact", true, nil}, + {"/co", true, nil}, + {"/con", false, nil}, // key mismatch + {"/cona", false, nil}, // key mismatch + {"/no", false, nil}, // no matching child + {"/ab", true, nil}, + {"/α", true, nil}, + {"/β", true, nil}, + {"/users/test", true, map[string]string{"id": "test"}}, + {"/books/title", true, nil}, + {"/search/test1/settings/test2", true, map[string]string{"item1": "test1", "item2": "test2"}}, + {"/search/test1", false, nil}, + {"test", false, nil}, + } + for _, request := range requests { + ctx := &context{paramValues: make(map[string]string)} + handler := tree.matchRoute(request.path, ctx) + + if handler == nil { + if request.match { + t.Errorf("handle mismatch for route '%s': Expected non-nil handle", request.path) + } + } else if !request.match { + t.Errorf("handle mismatch for route '%s': Expected nil handle", request.path) + } + + for expectedKey, expectedValue := range request.params { + actualValue := ctx.Param(expectedKey) + if actualValue != expectedValue { + t.Errorf(" mismatch for route '%s' parameter '%s' actual '%s', expected '%s'", + request.path, expectedKey, actualValue, expectedValue) + } + } + } +}