From fc388d28ae137dc1e5271dc6df5ad30cf30e5c98 Mon Sep 17 00:00:00 2001
From: Abdelrahman Ahmed
Date: Sun, 26 Jul 2020 14:16:55 +0200
Subject: [PATCH] Improve routing (#64)
* improve routing
* update go.sum
* fix allowed method calls
* update readme
* allow cache for post and get methods
* support serving static
* fix prefix lenght value
* improve context comments
* improve comments
* update readme
---
.github/workflows/test_and_build.yml | 2 +-
README.md | 80 ++--
cache.go | 85 -----
cache_test.go | 42 ---
context.go | 75 +++-
context_test.go | 6 +-
gearbox.go | 234 ++++++++----
gearbox_test.go | 322 ++++++++++++----
go.sum | 8 +-
router.go | 534 ++++++++-------------------
router_test.go | 523 +++-----------------------
tree.go | 168 +++++++++
tree_test.go | 133 +++++++
13 files changed, 1021 insertions(+), 1191 deletions(-)
delete mode 100644 cache.go
delete mode 100644 cache_test.go
create mode 100644 tree.go
create mode 100644 tree_test.go
diff --git a/.github/workflows/test_and_build.yml b/.github/workflows/test_and_build.yml
index 7130b80..c3900cf 100644
--- a/.github/workflows/test_and_build.yml
+++ b/.github/workflows/test_and_build.yml
@@ -43,7 +43,7 @@ jobs:
- name: Build
run: go build
-
+
- name: Lint
run: golint -set_exit_status ./...
diff --git a/README.md b/README.md
index c60903f..3b97378 100644
--- a/README.md
+++ b/README.md
@@ -20,17 +20,13 @@
-**gearbox** :gear: is a web framework for building micro services written in Go with a focus on high performance and memory optimization.
+**gearbox** :gear: is a web framework for building micro services written in Go with a focus on high performance and memory optimization. It's built on [fasthttp](https://github.com/valyala/fasthttp) which is **up to 10x faster** than net/http
-Currently, **gearbox** :gear: is **under development (not production ready)** and built on [fasthttp](https://github.com/valyala/fasthttp) which is **up to 10x faster** than net/http
-
-In **gearbox**, we care about peformance and memory which will be used by each method while building things up and how we can improve that. It also takes more time to **research** about each component that will be used and **compare** it with different implementations of other open source web frameworks. It may end up writing our **own components** in an optimized way to achieve our goals
### gearbox seeks to be
+ Secure :closed_lock_with_key:
+ Fast :rocket:
-+ Simple :eyeglasses:
-+ Easy to use
++ Easy to use :eyeglasses:
+ Lightweight
@@ -58,8 +54,8 @@ func main() {
gb := gearbox.New()
// Define your handlers
- gb.Get("/hello", func(ctx *gearbox.Context) {
- ctx.RequestCtx.Response.SetBodyString("Hello World!")
+ gb.Get("/hello", func(ctx gearbox.Context) {
+ ctx.SendString("Hello World!")
})
// Start service
@@ -80,18 +76,8 @@ func main() {
gb := gearbox.New()
// Handler with parameter
- gb.Get("/users/:user", func(ctx *gearbox.Context) {
- fmt.Printf("%s\n", ctx.Params["user"])
- })
-
- // Handler with optional parameter
- gb.Get("/search/:pattern?", func(ctx *gearbox.Context) {
- fmt.Printf("%s\n", ctx.Params["pattern"])
- })
-
- // Handler with regex parameter
- gb.Get("/book/:name:([a-z]+[0-3])", func(ctx *gearbox.Context) {
- fmt.Printf("%s\n", ctx.Params["name"])
+ gb.Get("/users/:user", func(ctx gearbox.Context) {
+ ctx.SendString(ctx.Param("user"))
})
// Start service
@@ -113,33 +99,32 @@ func main() {
gb := gearbox.New()
// create a logger middleware
- logMiddleware := func(ctx *gearbox.Context) {
- log.Printf(ctx.RequestCtx.String())
+ logMiddleware := func(ctx gearbox.Context) {
+ log.Printf("log message!")
ctx.Next() // Next is what allows the request to continue to the next middleware/handler
}
// create an unauthorized middleware
- unAuthorizedMiddleware := func(ctx *gearbox.Context) {
- ctx.RequestCtx.SetStatusCode(401) // unauthorized status code
- ctx.RequestCtx.Response.SetBodyString("You are unauthorized to access this page!")
+ unAuthorizedMiddleware := func(ctx gearbox.Context) {
+ ctx.Status(gearbox.StatusUnauthorized).SendString("You are unauthorized to access this page!")
}
// Register the log middleware for all requests
gb.Use(logMiddleware)
// Define your handlers
- gb.Get("/hello", func(ctx *gearbox.Context) {
- ctx.RequestCtx.Response.SetBodyString("Hello World!")
+ gb.Get("/hello", func(ctx gearbox.Context) {
+ ctx.SendString("Hello World!")
})
-
+
// Register the routes to be used when grouping routes
- routes := []*gearbox.Route {
- gb.Get("/id", func(ctx *gearbox.Context) {
- ctx.RequestCtx.Response.SetBodyString("User X")
+ routes := []*gearbox.Route{
+ gb.Get("/id", func(ctx gearbox.Context) {
+ ctx.SendString("User X")
+ }),
+ gb.Delete("/id", func(ctx gearbox.Context) {
+ ctx.SendString("Deleted")
}),
- gb.Delete("/id", func(ctx *gearbox.Context) {
- ctx.RequestCtx.Response.SetBodyString("Deleted")
- })
}
// Group account routes
@@ -150,8 +135,8 @@ func main() {
// Define a route with unAuthorizedMiddleware as the middleware
// you can define as many middlewares as you want and have the handler as the last argument
- gb.Get("/protected", unAuthorizedMiddleware, func(ctx *gearbox.Context) {
- ctx.RequestCtx.Response.SetBodyString("You accessed a protected page")
+ gb.Get("/protected", unAuthorizedMiddleware, func(ctx gearbox.Context) {
+ ctx.SendString("You accessed a protected page")
})
// Start service
@@ -159,6 +144,29 @@ func main() {
}
```
+#### Static Files
+
+```go
+package main
+
+import (
+ "github.com/gogearbox/gearbox"
+)
+
+func main() {
+ // Setup gearbox
+ gb := gearbox.New()
+
+ // Serve files in assets directory for prefix static
+ // for example /static/gearbox.png, etc.
+ gb.Static("/static", "./assets")
+
+ // Start service
+ gb.Start(":3000")
+}
+```
+
+
### Contribute & Support
+ Add a [GitHub Star](https://github.com/gogearbox/gearbox/stargazers)
+ [Suggest new features, ideas and optimizations](https://github.com/gogearbox/gearbox/issues)
diff --git a/cache.go b/cache.go
deleted file mode 100644
index 4345eb1..0000000
--- a/cache.go
+++ /dev/null
@@ -1,85 +0,0 @@
-package gearbox
-
-import (
- "container/list"
- "sync"
-)
-
-// Implementation of LRU caching using doubly linked list and tst
-
-// Cache returns LRU cache
-type Cache interface {
- Set(key string, value interface{})
- Get(key string) interface{}
-}
-
-// lruCache holds info used for caching internally
-type lruCache struct {
- capacity int
- list *list.List
- store map[string]interface{}
- mutex sync.RWMutex
-}
-
-// pair contains key and value of element
-type pair struct {
- key string
- value interface{}
-}
-
-// NewCache returns LRU cache
-func NewCache(capacity int) Cache {
- // minimum is 1
- if capacity <= 0 {
- capacity = 1
- }
-
- return &lruCache{
- capacity: capacity,
- list: new(list.List),
- store: make(map[string]interface{}),
- }
-}
-
-// Get returns value of provided key if it's existing
-func (c *lruCache) Get(key string) interface{} {
- c.mutex.RLock()
- defer c.mutex.RUnlock()
-
- // check if list node exists
- if node, ok := c.store[key].(*list.Element); ok {
- c.list.MoveToFront(node)
-
- return node.Value.(*pair).value
- }
- return nil
-}
-
-// Set adds a value to provided key in cache
-func (c *lruCache) Set(key string, value interface{}) {
- c.mutex.Lock()
- defer c.mutex.Unlock()
-
- // update the value if key is existing
- if node, ok := c.store[key].(*list.Element); ok {
- c.list.MoveToFront(node)
-
- node.Value.(*pair).value = value
- return
- }
-
- // remove last node if cache is full
- if c.list.Len() == c.capacity {
- lastNode := c.list.Back()
-
- // delete key's value
- delete(c.store, lastNode.Value.(*pair).key)
-
- c.list.Remove(lastNode)
- }
-
- c.store[key] = c.list.PushFront(&pair{
- key: key,
- value: value,
- })
-}
diff --git a/cache_test.go b/cache_test.go
deleted file mode 100644
index c5d0d50..0000000
--- a/cache_test.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package gearbox
-
-import (
- "fmt"
-)
-
-// ExampleNewCache tests Cache set and get methods
-func ExampleNewCache() {
- cache := NewCache(3)
- cache.Set("user1", 1)
- fmt.Println(cache.Get("user1").(int))
-
- cache.Set("user2", 2)
- fmt.Println(cache.Get("user2").(int))
-
- cache.Set("user3", 3)
- fmt.Println(cache.Get("user3").(int))
-
- cache.Set("user4", 4)
- fmt.Println(cache.Get("user1"))
- fmt.Println(cache.Get("user2").(int))
-
- cache.Set("user5", 5)
- fmt.Println(cache.Get("user3"))
-
- cache.Set("user5", 6)
- fmt.Println(cache.Get("user5").(int))
-
- cache2 := NewCache(0)
- cache2.Set("user1", 1)
- fmt.Println(cache2.Get("user1").(int))
-
- // Output:
- // 1
- // 2
- // 3
- //
- // 2
- //
- // 6
- // 1
-}
diff --git a/context.go b/context.go
index dcb6ecc..7738014 100644
--- a/context.go
+++ b/context.go
@@ -4,23 +4,80 @@ import (
"github.com/valyala/fasthttp"
)
+// Context interface
+type Context interface {
+ Next()
+ Context() *fasthttp.RequestCtx
+ Param(key string) string
+ Query(key string) string
+ SendString(value string) Context
+ Status(status int) Context
+ Set(key string, value string)
+ Get(key string) string
+ Body() string
+}
+
// handlerFunc defines the handler used by middleware as return value.
-type handlerFunc func(ctx *Context)
+type handlerFunc func(ctx Context)
// handlersChain defines a handlerFunc array.
type handlersChain []handlerFunc
// Context defines the current context of request and handlers/middlewares to execute
-type Context struct {
- RequestCtx *fasthttp.RequestCtx
- Params map[string]string
- handlers handlersChain
- index int
+type context struct {
+ requestCtx *fasthttp.RequestCtx
+ paramValues map[string]string
+ handlers handlersChain
+ index int
}
// Next function is used to successfully pass from current middleware to next middleware.
-// if the middleware thinks it's okay to pass it.
-func (ctx *Context) Next() {
+// if the middleware thinks it's okay to pass it
+func (ctx *context) Next() {
ctx.index++
- ctx.handlers[ctx.index](ctx)
+ if ctx.index < len(ctx.handlers) {
+ ctx.handlers[ctx.index](ctx)
+ }
+}
+
+// Param returns value of path parameter specified by key
+func (ctx *context) Param(key string) string {
+ return ctx.paramValues[key]
+}
+
+// Context returns Fasthttp context
+func (ctx *context) Context() *fasthttp.RequestCtx {
+ return ctx.requestCtx
+}
+
+// SendString sends body of response as a string
+func (ctx *context) SendString(value string) Context {
+ ctx.requestCtx.SetBodyString(value)
+ return ctx
+}
+
+// Status sets the HTTP status code
+func (ctx *context) Status(status int) Context {
+ ctx.requestCtx.Response.SetStatusCode(status)
+ return ctx
+}
+
+// Get returns the HTTP request header specified by field key
+func (ctx *context) Get(key string) string {
+ return GetString(ctx.requestCtx.Request.Header.Peek(key))
+}
+
+// Set sets the response's HTTP header field key to the specified key, value
+func (ctx *context) Set(key, value string) {
+ ctx.requestCtx.Response.Header.Set(key, value)
+}
+
+// Query returns the query string parameter in the request url
+func (ctx *context) Query(key string) string {
+ return GetString(ctx.requestCtx.QueryArgs().Peek(key))
+}
+
+// Body contains the raw body submitted in a POST request
+func (ctx *context) Body() string {
+ return GetString(ctx.requestCtx.Request.Body())
}
diff --git a/context_test.go b/context_test.go
index 8882506..13a9289 100644
--- a/context_test.go
+++ b/context_test.go
@@ -6,7 +6,7 @@ import (
)
// Test passing the request from middleware to handler
-func Test_Next(t *testing.T) {
+func TestNext(t *testing.T) {
// testing routes
routes := []struct {
path string
@@ -17,9 +17,7 @@ func Test_Next(t *testing.T) {
}
// get instance of gearbox
- gb := new(gearbox)
- gb.registeredRoutes = make([]*Route, 0)
- gb.settings = &Settings{}
+ gb := setupGearbox()
// register routes according to method
for _, r := range routes {
diff --git a/gearbox.go b/gearbox.go
index bfb2c40..0cc0bbe 100644
--- a/gearbox.go
+++ b/gearbox.go
@@ -2,10 +2,11 @@
package gearbox
import (
- "fmt"
"log"
"net"
"os"
+ "strings"
+ "sync"
"time"
"github.com/valyala/fasthttp"
@@ -14,20 +15,17 @@ import (
// Exported constants
const (
- Version = "1.0.3" // Version of gearbox
+ Version = "1.1.0" // Version of gearbox
Name = "Gearbox" // Name of gearbox
// http://patorjk.com/software/taag/#p=display&f=Big%20Money-ne&t=Gearbox
banner = `
- /$$$$$$ /$$
- /$$__ $$ | $$
-| $$ \__/ /$$$$$$ /$$$$$$ /$$$$$$ | $$$$$$$ /$$$$$$ /$$ /$$
-| $$ /$$$$ /$$__ $$ |____ $$ /$$__ $$| $$__ $$ /$$__ $$| $$ /$$/
-| $$|_ $$| $$$$$$$$ /$$$$$$$| $$ \__/| $$ \ $$| $$ \ $$ \ $$$$/
-| $$ \ $$| $$_____/ /$$__ $$| $$ | $$ | $$| $$ | $$ >$$ $$
-| $$$$$$/| $$$$$$$| $$$$$$$| $$ | $$$$$$$/| $$$$$$/ /$$/\ $$
- \______/ \_______/ \_______/|__/ |_______/ \______/ |__/ \__/ %s
-Listening on %s
-`
+ _____ _
+ / ____| | |
+| | __ ___ __ _ _ __ | |__ ___ __ __
+| | |_ | / _ \ / _' || '__|| '_ \ / _ \\ \/ /
+| |__| || __/| (_| || | | |_) || (_) |> <
+ \_____| \___| \__,_||_| |_.__/ \___//_/\_\ v%s
+Listening on %s`
)
const (
@@ -141,32 +139,44 @@ type Gearbox interface {
Connect(path string, handlers ...handlerFunc) *Route
Options(path string, handlers ...handlerFunc) *Route
Trace(path string, handlers ...handlerFunc) *Route
- Method(method, path string, handlers ...handlerFunc) *Route
- Fallback(handlers ...handlerFunc) error
- Use(middlewares ...handlerFunc)
Group(prefix string, routes []*Route) []*Route
+ Static(prefix, root string)
+ NotFound(handlers ...handlerFunc)
+ Use(middlewares ...handlerFunc)
}
// gearbox implements Gearbox interface
type gearbox struct {
- httpServer *fasthttp.Server
- routingTreeRoot *routeNode
- registeredRoutes []*Route
- address string // server address
- handlers handlersChain
- registeredFallback *routerFallback
- cache Cache
- settings *Settings
+ httpServer *fasthttp.Server
+ router *router
+ registeredRoutes []*Route
+ address string // server address
+ middlewares handlersChain
+ settings *Settings
}
// Settings struct holds server settings
type Settings struct {
// Enable case sensitive routing
- CaseSensitive bool // default false
+ CaseInSensitive bool // default false
// Maximum size of LRU cache that will be used in routing if it's enabled
CacheSize int // default 1000
+ // Enables answering with HTTP status code 405 if request does not match
+ // with any route, but there are another methods are allowed for that route
+ // otherwise answer with Not Found handlers or status code 404.
+ HandleMethodNotAllowed bool // default false
+
+ // Enables automatic replies to OPTIONS requests if there are no handlers
+ // registered for that route
+ HandleOPTIONS bool // default false
+
+ // Enables automatic recovering from panic while executing handlers by
+ // answering with HTTP status code 500 and logging error message without
+ // stopping service
+ AutoRecover bool // default false
+
// ServerName for sending in response headers
ServerName string // default ""
@@ -262,6 +272,17 @@ func New(settings ...*Settings) Gearbox {
gb.settings.Concurrency = defaultConcurrency
}
+ // Initialize router
+ gb.router = &router{
+ settings: gb.settings,
+ cache: make(map[string]*matchResult),
+ pool: sync.Pool{
+ New: func() interface{} {
+ return new(context)
+ },
+ },
+ }
+
gb.httpServer = gb.newHTTPServer()
return gb
@@ -269,12 +290,8 @@ func New(settings ...*Settings) Gearbox {
// Start handling requests
func (gb *gearbox) Start(address string) error {
- // Construct routing tree
- if err := gb.constructRoutingTree(); err != nil {
- return fmt.Errorf("unable to construct routing %s", err.Error())
- }
-
- gb.cache = NewCache(gb.settings.CacheSize)
+ // Setup router
+ gb.setupRouter()
if gb.settings.Prefork {
if !gb.settings.DisableStartupMessage {
@@ -311,12 +328,13 @@ func (gb *gearbox) Start(address string) error {
type customLogger struct{}
func (dl *customLogger) Printf(format string, args ...interface{}) {
+ //log.Printf(format)
}
// newHTTPServer returns a new instance of fasthttp server
func (gb *gearbox) newHTTPServer() *fasthttp.Server {
return &fasthttp.Server{
- Handler: gb.handler,
+ Handler: gb.router.Handler,
Logger: &customLogger{},
LogAllErrors: false,
Name: gb.settings.ServerName,
@@ -332,6 +350,34 @@ func (gb *gearbox) newHTTPServer() *fasthttp.Server {
}
}
+// registerRoute registers handlers with method and path
+func (gb *gearbox) registerRoute(method, path string, handlers handlersChain) *Route {
+ if gb.settings.CaseInSensitive {
+ path = strings.ToLower(path)
+ }
+
+ route := &Route{
+ Path: path,
+ Method: method,
+ Handlers: handlers,
+ }
+
+ // Add route to registered routes
+ gb.registeredRoutes = append(gb.registeredRoutes, route)
+ return route
+}
+
+// setupRouter initializes router with registered routes
+func (gb *gearbox) setupRouter() {
+ for _, route := range gb.registeredRoutes {
+ gb.router.handle(route.Method, route.Path, append(gb.middlewares, route.Handlers...))
+ }
+
+ // Frees intermediate stores after initializing router
+ gb.registeredRoutes = nil
+ gb.middlewares = nil
+}
+
// Stop serving
func (gb *gearbox) Stop() error {
err := gb.httpServer.Shutdown()
@@ -347,68 +393,50 @@ func (gb *gearbox) Stop() error {
// Get registers an http relevant method
func (gb *gearbox) Get(path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(MethodGet), string(path), handlers)
+ return gb.registerRoute(MethodGet, path, handlers)
}
// Head registers an http relevant method
func (gb *gearbox) Head(path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(MethodHead), string(path), handlers)
+ return gb.registerRoute(MethodHead, path, handlers)
}
// Post registers an http relevant method
func (gb *gearbox) Post(path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(MethodPost), string(path), handlers)
+ return gb.registerRoute(MethodPost, path, handlers)
}
// Put registers an http relevant method
func (gb *gearbox) Put(path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(MethodPut), string(path), handlers)
+ return gb.registerRoute(MethodPut, path, handlers)
}
// Patch registers an http relevant method
func (gb *gearbox) Patch(path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(MethodPatch), string(path), handlers)
+ return gb.registerRoute(MethodPatch, path, handlers)
}
// Delete registers an http relevant method
func (gb *gearbox) Delete(path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(MethodDelete), string(path), handlers)
+ return gb.registerRoute(MethodDelete, path, handlers)
}
// Connect registers an http relevant method
func (gb *gearbox) Connect(path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(MethodConnect), string(path), handlers)
+ return gb.registerRoute(MethodConnect, path, handlers)
}
// Options registers an http relevant method
func (gb *gearbox) Options(path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(MethodOptions), string(path), handlers)
+ return gb.registerRoute(MethodOptions, path, handlers)
}
// Trace registers an http relevant method
func (gb *gearbox) Trace(path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(MethodTrace), string(path), handlers)
-}
-
-// Trace registers an http relevant method
-func (gb *gearbox) Method(method, path string, handlers ...handlerFunc) *Route {
- return gb.registerRoute(string(method), string(path), handlers)
-}
-
-// Fallback registers an http handler only fired when no other routes match with request
-func (gb *gearbox) Fallback(handlers ...handlerFunc) error {
- return gb.registerFallback(handlers)
+ return gb.registerRoute(MethodTrace, path, handlers)
}
-// Use attaches a global middleware to the gearbox object.
-// included in the handlers chain for all matched requests.
-// it will always be executed before the handler and/or middlewares for the matched request
-// For example, this is the right place for a logger or some security check or permission checking.
-func (gb *gearbox) Use(middlewares ...handlerFunc) {
- gb.handlers = append(gb.handlers, middlewares...)
-}
-
-// Group appends a prefix to registered routes.
+// Group appends a prefix to registered routes
func (gb *gearbox) Group(prefix string, routes []*Route) []*Route {
for _, route := range routes {
route.Path = prefix + route.Path
@@ -416,25 +444,85 @@ func (gb *gearbox) Group(prefix string, routes []*Route) []*Route {
return routes
}
-// Handles all incoming requests and route them to proper handler according to
-// method and path
-func (gb *gearbox) handler(ctx *fasthttp.RequestCtx) {
- if handlers, params := gb.matchRoute(
- GetString(ctx.Request.Header.Method()),
- GetString(ctx.URI().Path())); handlers != nil {
- context := Context{
- RequestCtx: ctx,
- Params: params,
- handlers: append(gb.handlers, handlers...),
- index: 0,
+func (gb *gearbox) Static(prefix, root string) {
+ if gb.settings.CaseInSensitive {
+ prefix = strings.ToLower(prefix)
+ }
+
+ // remove trailing slash
+ if len(root) > 1 && root[len(root)-1] == '/' {
+ root = root[:len(root)-1]
+ }
+
+ if len(prefix) > 1 && prefix[len(prefix)-1] == '/' {
+ prefix = prefix[:len(prefix)-1]
+ }
+
+ fs := &fasthttp.FS{
+ Root: root,
+ IndexNames: []string{"index.html"},
+ PathRewrite: func(ctx *fasthttp.RequestCtx) []byte {
+ path := ctx.Path()
+
+ if len(path) >= len(prefix) {
+ path = path[len(prefix):]
+ }
+
+ if len(path) > 0 && path[0] != '/' {
+ path = append([]byte("/"), path...)
+ } else if len(path) == 0 {
+ path = []byte("/")
+ }
+ return path
+ },
+ }
+
+ fileHandler := fs.NewRequestHandler()
+ handler := func(ctx Context) {
+ fctx := ctx.Context()
+
+ fileHandler(fctx)
+
+ status := fctx.Response.StatusCode()
+ if status != StatusNotFound && status != StatusForbidden {
+ return
}
- context.handlers[0](&context)
- return
+
+ // Pass to custom not found handlers if there are
+ if gb.router.notFound != nil {
+ gb.router.notFound[0](ctx)
+ return
+ }
+
+ // Default Not Found response
+ fctx.Error(fasthttp.StatusMessage(fasthttp.StatusNotFound),
+ fasthttp.StatusNotFound)
}
- ctx.SetStatusCode(StatusNotFound)
+ // TODO: Improve
+ gb.Get(prefix, handler)
+
+ if len(prefix) > 1 && prefix[len(prefix)-1] != '*' {
+ gb.Get(prefix+"/*", handler)
+ }
+}
+
+// NotFound registers an http handlers that will be called when no other routes
+// match with request
+func (gb *gearbox) NotFound(handlers ...handlerFunc) {
+ gb.router.SetNotFound(handlers)
+}
+
+// Use attaches a global middleware to the gearbox object.
+// included in the handlers chain for all matched requests.
+// it will always be executed before the handler and/or middlewares for the matched request
+// For example, this is the right place for a logger or some security check or permission checking.
+func (gb *gearbox) Use(middlewares ...handlerFunc) {
+ gb.middlewares = append(gb.middlewares, middlewares...)
}
+// printStartupMessage prints gearbox info log message in parent process
+// and prints process id for child process
func printStartupMessage(addr string) {
if prefork.IsChild() {
log.Printf("Started child proc #%v\n", os.Getpid())
diff --git a/gearbox_test.go b/gearbox_test.go
index 797ac77..2fa4ed2 100644
--- a/gearbox_test.go
+++ b/gearbox_test.go
@@ -8,6 +8,9 @@ import (
"net"
"net/http"
"net/http/httputil"
+ "strconv"
+ "strings"
+ "sync"
"testing"
"time"
@@ -37,17 +40,48 @@ func (c *fakeConn) Write(b []byte) (int, error) {
return c.w.Write(b)
}
+// setupGearbox returns instace of gearbox struct
+func setupGearbox(settings ...*Settings) *gearbox {
+ gb := new(gearbox)
+ gb.registeredRoutes = make([]*Route, 0)
+
+ if len(settings) > 0 {
+ gb.settings = settings[0]
+ } else {
+ gb.settings = &Settings{}
+ }
+
+ gb.router = &router{
+ settings: gb.settings,
+ cache: make(map[string]*matchResult),
+ pool: sync.Pool{
+ New: func() interface{} {
+ return new(context)
+ },
+ },
+ }
+
+ return gb
+}
+
// startGearbox constructs routing tree and creates server
func startGearbox(gb *gearbox) {
- gb.cache = NewCache(defaultCacheSize)
- gb.constructRoutingTree()
+ gb.setupRouter()
gb.httpServer = &fasthttp.Server{
- Handler: gb.handler,
- Logger: nil,
+ Handler: gb.router.Handler,
+ Logger: &customLogger{},
LogAllErrors: false,
}
}
+// emptyHandler just an empty handler
+var emptyHandler = func(ctx Context) {}
+
+// empty Handlers chain is just an empty array
+var emptyHandlersChain = handlersChain{}
+
+var fakeHandlersChain = handlersChain{emptyHandler}
+
// makeRequest makes an http request to http server and returns response or error
func makeRequest(request *http.Request, gb *gearbox) (*http.Response, error) {
// Dump request to send it
@@ -78,36 +112,56 @@ func makeRequest(request *http.Request, gb *gearbox) (*http.Response, error) {
if err != nil {
return nil, err
}
-
return resp, nil
}
// handler just an empty handler
-var handler = func(ctx *Context) {}
+var handler = func(ctx Context) {}
+
+// errorHandler contains buggy code
+var errorHandler = func(ctx Context) {
+ m := make(map[string]int)
+ m["a"] = 0
+ ctx.SendString(string(5 / m["a"]))
+}
+
+// headerHandler echos header's value of key "my-header"
+var headerHandler = func(ctx Context) {
+ ctx.Set("custom", ctx.Get("my-header"))
+}
+
+// queryHandler answers with query's value of key "name"
+var queryHandler = func(ctx Context) {
+ ctx.SendString(ctx.Query("name"))
+}
+
+// bodyHandler answers with request body
+var bodyHandler = func(ctx Context) {
+ ctx.Context().Response.SetBodyString(ctx.Body())
+}
// unAuthorizedHandler sets status unauthorized in response
-var unAuthorizedHandler = func(ctx *Context) {
- ctx.RequestCtx.SetStatusCode(StatusUnauthorized)
+var unAuthorizedHandler = func(ctx Context) {
+ ctx.Status(StatusUnauthorized)
}
// pingHandler returns string pong in response body
-var pingHandler = func(ctx *Context) {
- ctx.RequestCtx.Response.SetBodyString("pong")
+var pingHandler = func(ctx Context) {
+ ctx.SendString("pong")
}
// fallbackHandler returns not found status with custom fallback handler in response body
-var fallbackHandler = func(ctx *Context) {
- ctx.RequestCtx.SetStatusCode(StatusNotFound)
- ctx.RequestCtx.Response.SetBodyString("custom fallback handler")
+var fallbackHandler = func(ctx Context) {
+ ctx.Status(StatusNotFound).SendString("custom fallback handler")
}
// emptyMiddleware does not stop the request and passes it to the next middleware/handler
-var emptyMiddleware = func(ctx *Context) {
+var emptyMiddleware = func(ctx Context) {
ctx.Next()
}
// registerRoute matches with register route request with available methods and calls it
-func registerRoute(gb Gearbox, method, path string, handler func(ctx *Context)) {
+func registerRoute(gb Gearbox, method, path string, handler func(ctx Context)) {
switch method {
case MethodGet:
gb.Get(path, handler)
@@ -127,8 +181,6 @@ func registerRoute(gb Gearbox, method, path string, handler func(ctx *Context))
gb.Options(path, handler)
case MethodTrace:
gb.Trace(path, handler)
- default:
- gb.Method(method, path, handler)
}
}
@@ -139,9 +191,13 @@ func TestMethods(t *testing.T) {
routes := []struct {
method string
path string
- handler func(ctx *Context)
+ handler func(ctx Context)
}{
+ {method: MethodGet, path: "/order/get", handler: queryHandler},
+ {method: MethodPost, path: "/order/add", handler: bodyHandler},
+ {method: MethodGet, path: "/books/find", handler: emptyHandler},
{method: MethodGet, path: "/articles/search", handler: emptyHandler},
+ {method: MethodPut, path: "/articles/search", handler: emptyHandler},
{method: MethodHead, path: "/articles/test", handler: emptyHandler},
{method: MethodPost, path: "/articles/204", handler: emptyHandler},
{method: MethodPost, path: "/articles/205", handler: unAuthorizedHandler},
@@ -149,17 +205,17 @@ func TestMethods(t *testing.T) {
{method: MethodPut, path: "/posts", handler: emptyHandler},
{method: MethodPatch, path: "/post/502", handler: emptyHandler},
{method: MethodDelete, path: "/post/a23011a", handler: emptyHandler},
- {method: MethodConnect, path: "/user/204", handler: emptyHandler},
- {method: MethodOptions, path: "/user/204/setting", handler: emptyHandler},
+ {method: MethodConnect, path: "/user/204", handler: headerHandler},
+ {method: MethodOptions, path: "/user/204/setting", handler: errorHandler},
{method: MethodTrace, path: "/users/*", handler: emptyHandler},
- {method: MethodTrace, path: "/users/test", handler: emptyHandler},
- {method: "CUSTOM", path: "/users/test/private", handler: emptyHandler},
}
// get instance of gearbox
gb := setupGearbox(&Settings{
- CaseSensitive: true,
- Prefork: true,
+ CaseInSensitive: true,
+ AutoRecover: true,
+ HandleOPTIONS: true,
+ HandleMethodNotAllowed: true,
})
// register routes according to method
@@ -172,33 +228,44 @@ func TestMethods(t *testing.T) {
// Requests that will be tested
testCases := []struct {
- method string
- path string
- statusCode int
- body string
+ method string
+ path string
+ statusCode int
+ requestBody string
+ body string
+ headers map[string]string
}{
+ {method: MethodGet, path: "/order/get?name=art123", statusCode: StatusOK, body: "art123"},
+ {method: MethodPost, path: "/order/add", requestBody: "testOrder", statusCode: StatusOK, body: "testOrder"},
+ {method: MethodPost, path: "/books/find", statusCode: StatusMethodNotAllowed, body: "Method Not Allowed", headers: map[string]string{"Allow": "GET, OPTIONS"}},
{method: MethodGet, path: "/articles/search", statusCode: StatusOK},
{method: MethodGet, path: "/articles/search", statusCode: StatusOK},
{method: MethodGet, path: "/Articles/search", statusCode: StatusOK},
- {method: MethodGet, path: "/articles/searching", statusCode: StatusNotFound},
+ {method: MethodOptions, path: "/articles/search", statusCode: StatusOK},
+ {method: MethodOptions, path: "*", statusCode: StatusOK},
+ {method: MethodOptions, path: "/*", statusCode: StatusOK},
+ {method: MethodGet, path: "/articles/searching", statusCode: StatusNotFound, body: "Not Found"},
{method: MethodHead, path: "/articles/test", statusCode: StatusOK},
{method: MethodPost, path: "/articles/204", statusCode: StatusOK},
{method: MethodPost, path: "/articles/205", statusCode: StatusUnauthorized},
{method: MethodPost, path: "/Articles/205", statusCode: StatusUnauthorized},
- {method: MethodPost, path: "/articles/206", statusCode: StatusNotFound},
+ {method: MethodPost, path: "/articles/206", statusCode: StatusNotFound, body: "Not Found"},
{method: MethodGet, path: "/ping", statusCode: StatusOK, body: "pong"},
{method: MethodPut, path: "/posts", statusCode: StatusOK},
{method: MethodPatch, path: "/post/502", statusCode: StatusOK},
{method: MethodDelete, path: "/post/a23011a", statusCode: StatusOK},
- {method: MethodConnect, path: "/user/204", statusCode: StatusOK},
- {method: MethodOptions, path: "/user/204/setting", statusCode: StatusOK},
+ {method: MethodConnect, path: "/user/204", statusCode: StatusOK, headers: map[string]string{"custom": "testing"}},
+ {method: MethodOptions, path: "/user/204/setting", statusCode: StatusInternalServerError, body: "Internal Server Error"},
{method: MethodTrace, path: "/users/testing", statusCode: StatusOK},
- {method: "CUSTOM", path: "/users/test/private", statusCode: StatusOK},
}
for _, tc := range testCases {
// create and make http request
- req, _ := http.NewRequest(tc.method, tc.path, nil)
+ req, _ := http.NewRequest(tc.method, tc.path, strings.NewReader(tc.requestBody))
+ req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Add("Content-Length", strconv.Itoa(len(tc.requestBody)))
+ req.Header.Set("my-header", "testing")
+
response, err := makeRequest(req, gb)
if err != nil {
@@ -220,13 +287,73 @@ func TestMethods(t *testing.T) {
if string(body) != tc.body {
t.Fatalf("%s(%s): returned %s expected %s", tc.method, tc.path, body, tc.body)
}
+
+ for expectedKey, expectedValue := range tc.headers {
+ actualValue := response.Header.Get(expectedKey)
+ if actualValue != expectedValue {
+ t.Errorf(" mismatch for route '%s' parameter '%s' actual '%s', expected '%s'",
+ tc.path, expectedKey, actualValue, expectedValue)
+ }
+ }
}
}
-// TestStart tests start service method
-func TestStart(t *testing.T) {
+func TestStatic(t *testing.T) {
+ // get instance of gearbox
+ gb := setupGearbox(&Settings{
+ CaseInSensitive: true,
+ AutoRecover: true,
+ HandleOPTIONS: true,
+ HandleMethodNotAllowed: true,
+ })
+
+ gb.Static("/static/", "./assets/")
+
+ // start serving
+ startGearbox(gb)
+
+ // Requests that will be tested
+ testCases := []struct {
+ method string
+ path string
+ statusCode int
+ body string
+ }{
+ {method: MethodGet, path: "/static/gearbox.png", statusCode: StatusOK},
+ }
+
+ for _, tc := range testCases {
+ // create and make http request
+ req, _ := http.NewRequest(tc.method, tc.path, nil)
+
+ response, err := makeRequest(req, gb)
+
+ if err != nil {
+ t.Fatalf("%s(%s): %s", tc.method, tc.path, err.Error())
+ }
+
+ // check status code
+ if response.StatusCode != tc.statusCode {
+ t.Fatalf("%s(%s): returned %d expected %d", tc.method, tc.path, response.StatusCode, tc.statusCode)
+ }
+
+ // read body from response
+ body, err := ioutil.ReadAll(response.Body)
+ if err != nil {
+ t.Fatalf("%s(%s): %s", tc.method, tc.path, err.Error())
+ }
+
+ // check response body
+ if tc.body != "" && string(body) != tc.body {
+ t.Fatalf("%s(%s): returned %s expected %s", tc.method, tc.path, body, tc.body)
+ }
+ }
+}
+
+// TestStartWithPrefork tests start service method
+func TestStartWithPrefork(t *testing.T) {
gb := New(&Settings{
- DisableStartupMessage: true,
+ Prefork: true,
})
go func() {
@@ -238,23 +365,35 @@ func TestStart(t *testing.T) {
}
// TestStart tests start service method
+func TestStart(t *testing.T) {
+ gb := New()
+
+ go func() {
+ time.Sleep(1000 * time.Millisecond)
+ gb.Stop()
+ }()
+
+ gb.Start(":3010")
+}
+
+// TestStartWithTLS tests start service method
func TestStartWithTLS(t *testing.T) {
gb := New(&Settings{
- DisableStartupMessage: true,
- TLSKeyPath: "ssl-cert-snakeoil.key",
- TLSCertPath: "ssl-cert-snakeoil.crt",
- TLSEnabled: true,
+ TLSKeyPath: "ssl-cert-snakeoil.key",
+ TLSCertPath: "ssl-cert-snakeoil.crt",
+ TLSEnabled: true,
})
+
// use a channel to hand off the error ( if any )
errs := make(chan error, 1)
go func() {
- time.Sleep(1000 * time.Millisecond)
- _, err := tls.DialWithDialer(&net.Dialer{
- Timeout: time.Second * 10,
- },
+ _, err := tls.DialWithDialer(
+ &net.Dialer{
+ Timeout: time.Second * 3,
+ },
"tcp",
- "localhost:3000",
+ "localhost:3050",
&tls.Config{
InsecureSkipVerify: true,
})
@@ -262,7 +401,7 @@ func TestStartWithTLS(t *testing.T) {
gb.Stop()
}()
- gb.Start(":3000")
+ gb.Start(":3050")
// wait for an error
err := <-errs
@@ -285,23 +424,6 @@ func TestStartInvalidListener(t *testing.T) {
}
}
-// TestStartConflictHandlers tests start with two handlers for the same path and method
-func TestStartConflictHandlers(t *testing.T) {
- gb := New()
-
- gb.Get("/test", handler)
- gb.Get("/test", handler)
-
- go func() {
- time.Sleep(1000 * time.Millisecond)
- gb.Stop()
- }()
-
- if err := gb.Start(":3001"); err == nil {
- t.Fatalf("invalid listener passed")
- }
-}
-
// TestStop tests stop service method
func TestStop(t *testing.T) {
gb := New()
@@ -315,17 +437,15 @@ func TestStop(t *testing.T) {
}
// TestRegisterFallback tests router fallback handler
-func TestRegisterFallback(t *testing.T) {
+func TestNotFound(t *testing.T) {
// get instance of gearbox
- gb := new(gearbox)
- gb.registeredRoutes = make([]*Route, 0)
- gb.settings = &Settings{}
+ gb := setupGearbox()
// register valid route
gb.Get("/ping", pingHandler)
- // register our fallback
- gb.Fallback(fallbackHandler)
+ // register not found handlers
+ gb.NotFound(fallbackHandler)
// start serving
startGearbox(gb)
@@ -368,12 +488,64 @@ func TestRegisterFallback(t *testing.T) {
}
}
-// Test Use function to try to register middlewares that work before all routes
-func Test_Use(t *testing.T) {
+// TestGroupRouting tests that you can do group routing
+func TestGroupRouting(t *testing.T) {
+ // create gearbox instance
+ gb := setupGearbox()
+ routes := []*Route{
+ gb.Get("/id", emptyHandler),
+ gb.Post("/abc", emptyHandler),
+ gb.Post("/abcd", emptyHandler),
+ }
+ gb.Group("/account", gb.Group("/api", routes))
+
+ // start serving
+ startGearbox(gb)
+
+ // One valid request, one invalid
+ testCases := []struct {
+ method string
+ path string
+ statusCode int
+ body string
+ }{
+ {method: MethodGet, path: "/account/api/id", statusCode: StatusOK},
+ {method: MethodPost, path: "/account/api/abc", statusCode: StatusOK},
+ {method: MethodPost, path: "/account/api/abcd", statusCode: StatusOK},
+ {method: MethodGet, path: "/id", statusCode: StatusNotFound, body: "Not Found"},
+ }
+
+ for _, tc := range testCases {
+ // create and make http request
+ req, _ := http.NewRequest(tc.method, tc.path, nil)
+ response, err := makeRequest(req, gb)
+
+ if err != nil {
+ t.Fatalf("%s(%s): %s", tc.method, tc.path, err.Error())
+ }
+
+ // check status code
+ if response.StatusCode != tc.statusCode {
+ t.Fatalf("%s(%s): returned %d expected %d", tc.method, tc.path, response.StatusCode, tc.statusCode)
+ }
+
+ // read body from response
+ body, err := ioutil.ReadAll(response.Body)
+ if err != nil {
+ t.Fatalf("%s(%s): %s", tc.method, tc.path, err.Error())
+ }
+
+ // check response body
+ if string(body) != tc.body {
+ t.Fatalf("%s(%s): returned %s expected %s", tc.method, tc.path, body, tc.body)
+ }
+ }
+}
+
+// TestUse tries to register middlewares that work before all routes
+func TestUse(t *testing.T) {
// get instance of gearbox
- gb := new(gearbox)
- gb.registeredRoutes = make([]*Route, 0)
- gb.settings = &Settings{}
+ gb := setupGearbox()
// register valid route
gb.Get("/ping", pingHandler)
diff --git a/go.sum b/go.sum
index 17894af..c880dea 100644
--- a/go.sum
+++ b/go.sum
@@ -1,23 +1,17 @@
github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4=
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
-github.com/klauspost/compress v1.10.4 h1:jFzIFaf586tquEB5EhzQG0HwGNSlgAJpG53G6Ss11wc=
-github.com/klauspost/compress v1.10.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg=
github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
-github.com/valyala/fasthttp v1.14.0 h1:67bfuW9azCMwW/Jlq/C+VeihNpAuJMWkYPBig1gdi3A=
-github.com/valyala/fasthttp v1.14.0/go.mod h1:ol1PCaL0dX20wC0htZ7sYCsvCYmrouYra0zHzaclZhE=
-github.com/valyala/fasthttp v1.15.0 h1:U2WxMim6rzae8NgjlDXNeH84c7XBBTeUcq+YXh5k+ok=
-github.com/valyala/fasthttp v1.15.0/go.mod h1:YOKImeEosDdBPnxc0gy7INqi3m1zK6A+xl6TwOBhHCA=
github.com/valyala/fasthttp v1.15.1 h1:eRb5jzWhbCn/cGu3gNJMcOfPUfXgXCcQIOHjh9ajAS8=
github.com/valyala/fasthttp v1.15.1/go.mod h1:YOKImeEosDdBPnxc0gy7INqi3m1zK6A+xl6TwOBhHCA=
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a h1:0R4NLDRDZX6JcmhJgXi5E4b8Wg84ihbmUKp/GvSPEzc=
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 h1:OjiUf46hAmXblsZdnoSXsEUSKU8r1UEzcL5RVZ4gO9Y=
golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
diff --git a/router.go b/router.go
index 7a89fc4..84dedd0 100644
--- a/router.go
+++ b/router.go
@@ -1,454 +1,212 @@
package gearbox
import (
- "fmt"
- "regexp"
- "sort"
+ "log"
"strings"
"sync"
-)
-
-type routeNode struct {
- Name string
- Endpoints map[string][]*endpoint
- Children map[string]*routeNode
-}
-
-type paramType uint8
-// Supported parameter types
-const (
- ptNoParam paramType = iota // No parameter (most strict)
- ptRegexp // Regex parameter
- ptParam // Normal parameter
- ptMatchAll // Match all parameter (least strict)
+ "github.com/valyala/fasthttp"
)
-type param struct {
- Name string
- Value string
- Type paramType
- IsOptional bool
-}
-
-type endpoint struct {
- Params []*param
- Handlers handlersChain
-}
-
-type routerFallback struct {
- Handlers handlersChain
-}
-
-type matchParamsResult struct {
- Matched bool
- Handlers handlersChain
- Params map[string]string
-}
-
-// validateRoutePath makes sure that path complies with path's rules
-func validateRoutePath(path string) error {
- // Check length of the path
- length := len(path)
- if length == 0 {
- return fmt.Errorf("length is zero")
- }
-
- if length == defaultMaxRequestURLLength {
- return fmt.Errorf("length request url exceed the max limit")
- }
-
- // Make sure path starts with /
- if path[0] != '/' {
- return fmt.Errorf("path must start with /")
- }
-
- params := make(map[string]bool)
- parts := strings.Split(trimPath(path), "/")
- partsLen := len(parts)
- for i := 0; i < partsLen; i++ {
- if parts[i] == "" {
- continue
- }
- if p := parseParameter(parts[i]); p != nil {
- if p.Type == ptMatchAll && i != partsLen-1 {
- return fmt.Errorf("* must be in the end of path")
- } else if p.IsOptional && i != partsLen-1 {
- return fmt.Errorf("only last parameter can be optional")
- } else if p.Type == ptParam || p.Type == ptRegexp {
- if _, ok := params[p.Name]; ok {
- return fmt.Errorf("parameter is duplicated")
- }
- params[p.Name] = true
- }
- }
- }
-
- return nil
-}
-
-// registerRoute registers handler with method and path
-func (gb *gearbox) registerRoute(method, path string, handlers handlersChain) *Route {
-
- if gb.settings.CaseSensitive {
- path = strings.ToLower(path)
- }
-
- route := &Route{
- Path: path,
- Method: method,
- Handlers: handlers,
- }
+var (
+ defaultContentType = []byte("text/plain; charset=utf-8")
+)
- // Add route to registered routes
- gb.registeredRoutes = append(gb.registeredRoutes, route)
- return route
+type router struct {
+ trees map[string]*node
+ cache map[string]*matchResult
+ cacheLen int
+ mutex sync.RWMutex
+ notFound handlersChain
+ settings *Settings
+ pool sync.Pool
}
-// registerFallback registers a single handler that will match only if all other routes fail to match
-func (gb *gearbox) registerFallback(handlers handlersChain) error {
- // Handler is not provided
- if handlers == nil {
- return fmt.Errorf("fallback does not contain a handler")
- }
-
- gb.registeredFallback = &routerFallback{Handlers: handlers}
- return nil
+type matchResult struct {
+ handlers handlersChain
+ params map[string]string
}
-// createEmptyRouteNode creates a new route node with name
-func createEmptyRouteNode(name string) *routeNode {
- return &routeNode{
- Name: name,
- Children: make(map[string]*routeNode),
- Endpoints: make(map[string][]*endpoint),
- }
-}
-
-// parseParameter parses path part into param struct, or returns nil if it's
-// not a parameter
-func parseParameter(pathPart string) *param {
- pathPartLen := len(pathPart)
- if pathPartLen == 0 {
- return nil
- }
-
- // match all
- if pathPart[0] == '*' {
- return ¶m{
- Name: "*",
- Type: ptMatchAll,
- }
- }
-
- isOptional := pathPart[pathPartLen-1] == '?'
- if isOptional {
- pathPart = pathPart[0 : pathPartLen-1]
- }
+// acquireCtx returns instance of context after initializing it
+func (r *router) acquireCtx(fctx *fasthttp.RequestCtx) *context {
+ ctx := r.pool.Get().(*context)
- params := strings.Split(pathPart, ":")
- paramsLen := len(params)
+ // Initialize
+ ctx.index = 0
+ ctx.paramValues = make(map[string]string)
+ ctx.requestCtx = fctx
- if paramsLen == 2 && params[0] == "" { // Just a parameter
- return ¶m{
- Name: params[1],
- Type: ptParam,
- IsOptional: isOptional,
- }
- } else if paramsLen == 3 && params[0] == "" { // Regex parameter
- return ¶m{
- Name: params[1],
- Value: params[2],
- Type: ptRegexp,
- IsOptional: isOptional,
- }
- }
-
- return nil
+ return ctx
}
-// getLeastStrictParamType returns least strict parameter type from list of
-// parameters
-func getLeastStrictParamType(params []*param) paramType {
- pLen := len(params)
- if pLen == 0 {
- return ptNoParam
- }
-
- pType := params[0].Type
- for i := 1; i < pLen; i++ {
- if params[i].Type > pType {
- pType = params[i].Type
- }
- }
- return pType
+// releaseCtx frees context
+func (r *router) releaseCtx(ctx *context) {
+ ctx.handlers = nil
+ ctx.paramValues = nil
+ ctx.requestCtx = nil
+ r.pool.Put(ctx)
}
-func isValidEndpoint(endpoints []*endpoint, newEndpoint *endpoint) bool {
- for i := range endpoints {
- if len(endpoints[i].Params) == len(newEndpoint.Params) {
- isValid := false
- for j := range endpoints[i].Params {
- if endpoints[i].Params[j].Type != newEndpoint.Params[j].Type {
- isValid = true
- }
- }
- return isValid
- }
+// handle registers handlers for provided method and path to be used
+// in routing incoming requests
+func (r *router) handle(method, path string, handlers handlersChain) {
+ if path == "" {
+ panic("path is empty")
+ } else if method == "" {
+ panic("method is empty")
+ } else if path[0] != '/' {
+ panic("path must begin with '/' in path '" + path + "'")
+ } else if len(handlers) == 0 {
+ panic("no handlers provided with path '" + path + "'")
}
- return true
-}
-
-// trimPath trims left and right slashes in path
-func trimPath(path string) string {
- pathLastIndex := len(path) - 1
- for path[pathLastIndex] == '/' && pathLastIndex > 0 {
- pathLastIndex--
+ // initialize tree if it's empty
+ if r.trees == nil {
+ r.trees = make(map[string]*node)
}
- pathFirstIndex := 1
- if path[0] != '/' {
- pathFirstIndex = 0
+ // get root of method if it's existing, otherwise creates it
+ root := r.trees[method]
+ if root == nil {
+ root = createRootNode()
+ r.trees[method] = root
}
- return path[pathFirstIndex : pathLastIndex+1]
+ root.addRoute(path, handlers)
}
-// constructRoutingTree constructs routing tree according to provided routes
-func (gb *gearbox) constructRoutingTree() error {
- // Firstly, create root node
- gb.routingTreeRoot = createEmptyRouteNode("root")
-
- for _, route := range gb.registeredRoutes {
- currentNode := gb.routingTreeRoot
+// allowed checks if provided path can be routed in another method(s)
+func (r *router) allowed(reqMethod, path string, ctx *context) string {
+ var allow string
- // Handler is not provided
- if route.Handlers == nil {
- return fmt.Errorf("route %s with method %s does not contain any handlers", route.Path, route.Method)
- }
-
- // Check if path is valid or not
- if err := validateRoutePath(route.Path); err != nil {
- return fmt.Errorf("route %s is not valid! - %s", route.Path, err.Error())
- }
-
- params := make([]*param, 0)
-
- // Split path into slices of parts
- parts := strings.Split(route.Path, "/")
-
- partsLen := len(parts)
- for i := 1; i < partsLen; i++ {
- part := parts[i]
-
- // Do not create node if part is empty
- if part == "" {
+ pathLen := len(path)
+ if (pathLen == 1 && path[0] == '*') || (pathLen > 1 && path[1] == '*') {
+ for method := range r.trees {
+ if method == MethodOptions {
continue
}
- // Parse part as a parameter if it is
- if param := parseParameter(part); param != nil {
- params = append(params, param)
- continue
+ if allow != "" {
+ allow += ", " + method
+ } else {
+ allow = method
}
-
- // Try to get a child of current node with part, otherwise
- //creates a new node and make it current node
- partNode, ok := currentNode.Children[part]
- if !ok {
- partNode = createEmptyRouteNode(part)
- currentNode.Children[part] = partNode
- }
- currentNode = partNode
- }
-
- currentEndpoint := &endpoint{
- Handlers: route.Handlers,
- Params: params,
}
-
- // Make sure that current node does not have a handler for route's method
- var endpoints []*endpoint
- if result, ok := currentNode.Endpoints[route.Method]; ok {
- if ok := isValidEndpoint(result, currentEndpoint); !ok {
- return fmt.Errorf("there already registered method %s for %s", route.Method, route.Path)
- }
-
- endpoints = append(result, currentEndpoint)
- sort.Slice(endpoints, func(i, j int) bool {
- iLen := len(endpoints[i].Params)
- jLen := len(endpoints[j].Params)
- if iLen == jLen {
- iParamType := getLeastStrictParamType(endpoints[i].Params)
- jParamType := getLeastStrictParamType(endpoints[j].Params)
- return iParamType < jParamType
- }
-
- return iLen > jLen
- })
- } else {
- endpoints = []*endpoint{currentEndpoint}
- }
-
- // Save handler to route's method for current node
- currentNode.Endpoints[route.Method] = endpoints
- }
- return nil
-}
-
-// matchRoute matches provided method and path with handler if it's existing
-func (gb *gearbox) matchRoute(method, path string) (handlersChain, map[string]string) {
- if handlers, params := gb.matchRouteAgainstRegistered(method, path); handlers != nil {
- return handlers, params
+ return allow
}
- if gb.registeredFallback != nil && gb.registeredFallback.Handlers != nil {
- return gb.registeredFallback.Handlers, make(map[string]string)
- }
-
- return nil, nil
-}
-
-func matchEndpointParams(ep *endpoint, paths []string, pathIndex int) (map[string]string, bool) {
- endpointParams := ep.Params
- endpointParamsLen := len(endpointParams)
- pathsLen := len(paths)
- paramDic := make(map[string]string, endpointParamsLen)
-
- paramIdx := 0
- for paramIdx < endpointParamsLen {
- if endpointParams[paramIdx].Type == ptMatchAll {
- // Last parameter, so we can return
- return paramDic, true
- }
-
- // path has ended and there is more parameters to match
- if pathIndex >= pathsLen {
- // If it's optional means this is the last parameter.
- if endpointParams[paramIdx].IsOptional {
- return paramDic, true
- }
-
- return nil, false
- }
-
- if paths[pathIndex] == "" {
- pathIndex++
+ for method, tree := range r.trees {
+ if method == reqMethod || method == MethodOptions {
continue
}
- if endpointParams[paramIdx].Type == ptParam {
- paramDic[endpointParams[paramIdx].Name] = paths[pathIndex]
- } else if endpointParams[paramIdx].Type == ptRegexp {
- if match, _ := regexp.MatchString(endpointParams[paramIdx].Value, paths[pathIndex]); match {
- paramDic[endpointParams[paramIdx].Name] = paths[pathIndex]
- } else if !endpointParams[paramIdx].IsOptional {
- return nil, false
+ handlers := tree.matchRoute(path, ctx)
+ if handlers != nil {
+ if allow != "" {
+ allow += ", " + method
+ } else {
+ allow = method
}
}
-
- paramIdx++
- pathIndex++
}
- for pathIndex < pathsLen && paths[pathIndex] == "" {
- pathIndex++
+ if len(allow) > 0 {
+ allow += ", " + MethodOptions
}
-
- // There is more parts, so no match
- if pathsLen-pathIndex > 0 {
- return nil, false
- }
-
- return paramDic, true
+ return allow
}
-func matchNodeEndpoints(node *routeNode, method string, paths []string,
- pathIndex int, result *matchParamsResult, wg *sync.WaitGroup) {
- if endpoints, ok := node.Endpoints[method]; ok {
- for j := range endpoints {
- if params, matched := matchEndpointParams(endpoints[j], paths, pathIndex); matched {
- result.Matched = true
- result.Params = params
- result.Handlers = endpoints[j].Handlers
- wg.Done()
- return
+// Handler handles all incoming requests
+func (r *router) Handler(fctx *fasthttp.RequestCtx) {
+ context := r.acquireCtx(fctx)
+ defer r.releaseCtx(context)
+
+ if r.settings.AutoRecover {
+ defer func(fctx *fasthttp.RequestCtx) {
+ if rcv := recover(); rcv != nil {
+ log.Printf("recovered from error: %v", rcv)
+ fctx.Error(fasthttp.StatusMessage(fasthttp.StatusInternalServerError),
+ fasthttp.StatusInternalServerError)
}
- }
+ }(fctx)
}
- result.Matched = false
- wg.Done()
-}
-
-func (gb *gearbox) matchRouteAgainstRegistered(method, path string) (handlersChain, map[string]string) {
- // Start with root node
- currentNode := gb.routingTreeRoot
+ path := GetString(fctx.URI().PathOriginal())
- // Return if root is empty, or path is not valid
- if currentNode == nil || path == "" || path[0] != '/' || len(path) > defaultMaxRequestURLLength {
- return nil, nil
- }
-
- if gb.settings.CaseSensitive {
+ if r.settings.CaseInSensitive {
path = strings.ToLower(path)
}
- trimmedPath := trimPath(path)
+ method := GetString(fctx.Method())
+
+ var cacheKey string
+ useCache := !r.settings.DisableCaching &&
+ (method == MethodGet || method == MethodPost)
+ if useCache {
+ cacheKey = path + method
+ r.mutex.RLock()
+ cacheResult, ok := r.cache[cacheKey]
- // Try to get from cache if it's enabled
- cacheKey := ""
- if !gb.settings.DisableCaching {
- cacheKey = method + trimmedPath
- if cacheResult, ok := gb.cache.Get(cacheKey).(*matchParamsResult); ok {
- return cacheResult.Handlers, cacheResult.Params
+ if ok {
+ context.handlers = cacheResult.handlers
+ context.paramValues = cacheResult.params
+ r.mutex.RUnlock()
+ context.handlers[0](context)
+ return
}
+ r.mutex.RUnlock()
}
- paths := strings.Split(trimmedPath, "/")
+ if root := r.trees[method]; root != nil {
+ if handlers := root.matchRoute(path, context); handlers != nil {
+ context.handlers = handlers
+ context.handlers[0](context)
- var wg sync.WaitGroup
- lastMatchedNodes := []*matchParamsResult{{}}
- lastMatchedNodesIndex := 1
- wg.Add(1)
- go matchNodeEndpoints(currentNode, method, paths, 0, lastMatchedNodes[0], &wg)
+ if useCache {
+ r.mutex.Lock()
- for i := range paths {
- if paths[i] == "" {
- continue
+ if r.cacheLen == r.settings.CacheSize {
+ r.cache = make(map[string]*matchResult)
+ r.cacheLen = 0
+ }
+ r.cache[cacheKey] = &matchResult{
+ handlers: handlers,
+ params: context.paramValues,
+ }
+ r.cacheLen++
+ r.mutex.Unlock()
+ }
+ return
}
+ }
- // Try to match part with a child of current node
- pathNode, ok := currentNode.Children[paths[i]]
- if !ok {
- break
+ if method == MethodOptions && r.settings.HandleOPTIONS {
+ if allow := r.allowed(method, path, context); len(allow) > 0 {
+ fctx.Response.Header.Set("Allow", allow)
+ return
+ }
+ } else if r.settings.HandleMethodNotAllowed {
+ if allow := r.allowed(method, path, context); len(allow) > 0 {
+ fctx.Response.Header.Set("Allow", allow)
+ fctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
+ fctx.SetContentTypeBytes(defaultContentType)
+ fctx.SetBodyString(fasthttp.StatusMessage(fasthttp.StatusMethodNotAllowed))
+ return
}
-
- // set matched node as current node
- currentNode = pathNode
-
- v := matchParamsResult{}
- lastMatchedNodes = append(lastMatchedNodes, &v)
- wg.Add(1)
- go matchNodeEndpoints(currentNode, method, paths, i+1, &v, &wg)
- lastMatchedNodesIndex++
}
- wg.Wait()
-
- // Return longest prefix match
- for i := lastMatchedNodesIndex - 1; i >= 0; i-- {
- if lastMatchedNodes[i].Matched {
- if !gb.settings.DisableCaching {
- go func(key string, matchResult *matchParamsResult) {
- gb.cache.Set(key, matchResult)
- }(string(cacheKey), lastMatchedNodes[i])
- }
-
- return lastMatchedNodes[i].Handlers, lastMatchedNodes[i].Params
- }
+ // Custom Not Found (404) handlers
+ if r.notFound != nil {
+ r.notFound[0](context)
+ return
}
- return nil, nil
+ // Default Not Found response
+ fctx.Error(fasthttp.StatusMessage(fasthttp.StatusNotFound),
+ fasthttp.StatusNotFound)
+}
+
+// SetNotFound appends handlers to custom not found (404) handlers
+func (r *router) SetNotFound(handlers handlersChain) {
+ r.notFound = append(r.notFound, handlers...)
}
diff --git a/router_test.go b/router_test.go
index 66f8673..75037b6 100644
--- a/router_test.go
+++ b/router_test.go
@@ -1,493 +1,74 @@
package gearbox
import (
- "fmt"
+ "sync"
"testing"
)
-// setupGearbox returns instace of gearbox struct
-func setupGearbox(settings ...*Settings) *gearbox {
- gb := new(gearbox)
- gb.registeredRoutes = make([]*Route, 0)
-
- if len(settings) > 0 {
- gb.settings = settings[0]
- } else {
- gb.settings = &Settings{}
- }
-
- gb.cache = NewCache(defaultCacheSize)
- return gb
-}
-
-// TestValidateRoutePath tests if provided paths are valid or not
-func TestValidateRoutePath(t *testing.T) {
- // test cases
- tests := []struct {
- input string
- isErr bool
- }{
- {input: "", isErr: true},
- {input: "user", isErr: true},
- {input: "/user", isErr: false},
- {input: "/admin/", isErr: false},
- {input: "/user/*/get", isErr: true},
- {input: "/user/*", isErr: false},
- {input: "/user/:name", isErr: false},
- {input: "/user/:name/:name", isErr: true},
- {input: "/user/:name?/get", isErr: true},
- }
-
- for _, tt := range tests {
- err := validateRoutePath(tt.input)
- if (err != nil && !tt.isErr) || (err == nil && tt.isErr) {
- errMsg := ""
-
- // get error message if there is
- if err != nil {
- errMsg = err.Error()
- }
-
- t.Errorf("input %s find error %t %s expecting error %t", tt.input, err == nil, errMsg, tt.isErr)
- }
- }
-}
-
-// TestCreateEmptyNode tests creating route node with specific name
-func TestCreateEmptyNode(t *testing.T) {
- name := "test_node"
- node := createEmptyRouteNode(name)
-
- if node == nil {
- // node.Name != name {
- t.Errorf("find name %s expecting name %s", node.Name, name)
- }
-}
-
-// emptyHandler just an empty handler
-var emptyHandler = func(ctx *Context) {}
-
-// empty Handlers chain is just an empty array
-var emptyHandlersChain = handlersChain{}
-
-// TestRegisterRoute tests registering routes after validating it
-func TestRegisterRoute(t *testing.T) {
- // test cases
- tests := []struct {
- method string
- path string
- handler handlersChain
- isErr bool
- }{
- {method: MethodPut, path: "/admin/welcome", handler: emptyHandlersChain, isErr: false},
- {method: MethodPost, path: "/user/add", handler: emptyHandlersChain, isErr: false},
- {method: MethodGet, path: "/account/get", handler: emptyHandlersChain, isErr: false},
- {method: MethodGet, path: "/account/*", handler: emptyHandlersChain, isErr: false},
- {method: MethodGet, path: "/account/*", handler: emptyHandlersChain, isErr: false},
- {method: MethodDelete, path: "/account/delete", handler: emptyHandlersChain, isErr: false},
- {method: MethodDelete, path: "/account/delete", handler: nil, isErr: true},
- {method: MethodGet, path: "/account/*/getAccount", handler: nil, isErr: true},
- {method: MethodGet, path: "/books/:name/:test", handler: emptyHandlersChain, isErr: false},
- {method: MethodGet, path: "/books/:name/:name", handler: nil, isErr: true},
- }
-
- // counter for valid routes
- validCounter := 0
-
- for _, tt := range tests {
- // create gearbox instance so old errors don't affect new routes
- gb := setupGearbox()
-
- gb.registerRoute(tt.method, tt.path, tt.handler)
- err := gb.constructRoutingTree()
- if (err != nil && !tt.isErr) || (err == nil && tt.isErr) {
- errMsg := ""
-
- // get error message if there is
- if err != nil {
- errMsg = err.Error()
- }
-
- t.Errorf("input %v find error %t %s expecting error %t", tt, err == nil, errMsg, tt.isErr)
- }
-
- if !tt.isErr {
- validCounter++
- }
- }
-}
-
-// TestRegisterInvalidRoute tests registering invalid routes
-func TestRegisterInvalidRoute(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox()
-
- // test handler is nil
- gb.registerRoute(MethodGet, "invalid Path", emptyHandlersChain)
-
- if err := gb.constructRoutingTree(); err == nil {
- t.Errorf("input GET invalid Path find nil expecting error")
- }
-}
-
-// TestParseParameter tests parsing parameters into param struct
-func TestParseParameter(t *testing.T) {
- tests := []struct {
- path string
- output *param
- }{
- {path: ":test", output: ¶m{Name: "test", Type: ptParam}},
- {path: ":test2:[a-z]", output: ¶m{Name: "test2", Value: "[a-z]", Type: ptRegexp}},
- {path: "*", output: ¶m{Name: "*", Type: ptMatchAll}},
- {path: "user:[a-z]", output: nil},
- {path: "user", output: nil},
- {path: "", output: nil},
- }
-
- for _, test := range tests {
- p := parseParameter(test.path)
- if test.output == nil && p == nil {
- continue
- }
- if (test.output == nil && p != nil) ||
- (test.output != nil && p == nil) ||
- (test.output.Name != p.Name) ||
- (test.output.Type != p.Type) ||
- (test.output.Value != p.Value) {
- t.Errorf("path %s, find %v expected %v", test.path, p, test.output)
- }
- }
-}
-
-// TestGetLeastStrictParamType test
-func TestGetLeastStrictParamType(t *testing.T) {
- tests := []struct {
- params []*param
- output paramType
- }{
- {params: []*param{}, output: ptNoParam},
- {params: []*param{
- {Type: ptParam, Name: "name"},
- {Type: ptRegexp, Name: "test", Value: "[a-z]"},
- {Type: ptMatchAll, Name: "*"},
- }, output: ptMatchAll},
- {params: []*param{
- {Type: ptParam, Name: "name"},
- {Type: ptMatchAll, Name: "*"},
- }, output: ptMatchAll},
- {params: []*param{
- {Type: ptParam, Name: "name"},
- {Type: ptRegexp, Name: "test3", Value: "[a-z]"},
- }, output: ptParam},
- {params: []*param{
- {Type: ptRegexp, Name: "test3", Value: "[a-z]"},
- {Type: ptMatchAll, Name: "*"},
- }, output: ptMatchAll},
- }
-
- for _, test := range tests {
- paramType := getLeastStrictParamType(test.params)
- if paramType != test.output {
- t.Errorf("params %v, find %d expected %d", test.params, paramType, test.output)
- }
- }
-}
-
-// TestTrimPath test
-func TestTrimPath(t *testing.T) {
- tests := []struct {
- input string
- output string
- }{
- {input: "/", output: ""},
- {input: "/test/", output: "test"},
- {input: "test2/", output: "test2"},
- {input: "test2", output: "test2"},
- {input: "/user/test", output: "user/test"},
- {input: "/books/get/test/", output: "books/get/test"},
- }
-
- for _, test := range tests {
- trimmedPath := trimPath(test.input)
- if trimmedPath != test.output {
- t.Errorf("path %s, find %s expected %s", test.input, trimmedPath, test.output)
- }
- }
-}
-
-// TestIsValidEndpoint test
-func TestIsValidEndpoint(t *testing.T) {
- tests := []struct {
- endpoints []*endpoint
- newEndpoint *endpoint
- output bool
- }{
- {endpoints: []*endpoint{}, newEndpoint: &endpoint{}, output: true},
- {endpoints: []*endpoint{
- {Handlers: handlersChain{emptyHandler}, Params: []*param{
- {Name: "user", Type: ptParam},
- {Name: "name", Type: ptParam},
- }},
- {Handlers: handlersChain{emptyHandler}, Params: []*param{}},
- }, newEndpoint: &endpoint{Handlers: handlersChain{emptyHandler}, Params: []*param{
- {Name: "test", Type: ptParam},
- }}, output: true},
- {endpoints: []*endpoint{}, newEndpoint: &endpoint{}, output: true},
- {endpoints: []*endpoint{
- {Handlers: handlersChain{emptyHandler}, Params: []*param{
- {Name: "user", Type: ptParam},
- }},
- {Handlers: handlersChain{emptyHandler}, Params: []*param{}},
- }, newEndpoint: &endpoint{Handlers: handlersChain{emptyHandler}, Params: []*param{
- {Name: "test", Type: ptParam},
- }}, output: false},
- {endpoints: []*endpoint{
- {Handlers: handlersChain{emptyHandler}, Params: []*param{
- {Name: "user", Type: ptRegexp, Value: "[a-z]"},
- }},
- {Handlers: handlersChain{emptyHandler}, Params: []*param{}},
- }, newEndpoint: &endpoint{Handlers: handlersChain{emptyHandler}, Params: []*param{
- {Name: "test", Type: ptParam},
- }}, output: true},
- {endpoints: []*endpoint{
- {Handlers: handlersChain{emptyHandler}, Params: []*param{
- {Name: "*", Type: ptMatchAll},
- }},
- {Handlers: handlersChain{emptyHandler}, Params: []*param{}},
- }, newEndpoint: &endpoint{Handlers: handlersChain{emptyHandler}, Params: []*param{
- {Name: "test", Type: ptRegexp, Value: "[a-z]"},
- }}, output: true},
- }
-
- for _, test := range tests {
- isValid := isValidEndpoint(test.endpoints, test.newEndpoint)
- if isValid != test.output {
- t.Errorf("endpoints %v, new endpoint %v find %t expected %t", test.endpoints,
- test.newEndpoint, isValid, test.output)
- }
- }
-}
-
-// TestConstructRoutingTree tests constructing routing tree and matching routes properly
-func TestConstructRoutingTree(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox(&Settings{
- CacheSize: 1,
- })
-
+func TestHandle(t *testing.T) {
// testing routes
routes := []struct {
- method string
- path string
- handler handlersChain
+ method string
+ path string
+ conflict bool
+ handlers handlersChain
}{
- {method: MethodGet, path: "/articles/search", handler: emptyHandlersChain},
- {method: MethodGet, path: "/articles/test", handler: emptyHandlersChain},
- {method: MethodGet, path: "/articles/204", handler: emptyHandlersChain},
- {method: MethodGet, path: "/posts", handler: emptyHandlersChain},
- {method: MethodGet, path: "/post/502", handler: emptyHandlersChain},
- {method: MethodGet, path: "/post/a23011a", handler: emptyHandlersChain},
- {method: MethodGet, path: "/user/204", handler: emptyHandlersChain},
- {method: MethodGet, path: "/user/205/", handler: emptyHandlersChain},
- {method: MethodPost, path: "/user/204/setting", handler: emptyHandlersChain},
- {method: MethodGet, path: "/users/*", handler: emptyHandlersChain},
- {method: MethodGet, path: "/books/get/:name", handler: emptyHandlersChain},
- {method: MethodGet, path: "/books/get/*", handler: emptyHandlersChain},
- {method: MethodGet, path: "/books/search/:pattern:([a-z]+", handler: emptyHandlersChain},
- {method: MethodGet, path: "/books/search/:pattern", handler: emptyHandlersChain},
- {method: MethodGet, path: "/books/search/:pattern1/:pattern2/:pattern3", handler: emptyHandlersChain},
- {method: MethodGet, path: "/books//search/*", handler: emptyHandlersChain},
- {method: MethodGet, path: "/account/:name?", handler: emptyHandlersChain},
- {method: MethodGet, path: "/profile/:name:([a-z]+)?", handler: emptyHandlersChain},
- {method: MethodGet, path: "/order/:name1/:name2:([a-z]+)?", handler: emptyHandlersChain},
- {method: MethodGet, path: "/", handler: emptyHandlersChain},
- }
-
- // register routes
- for _, r := range routes {
- gb.registerRoute(r.method, r.path, r.handler)
- }
-
- gb.constructRoutingTree()
-
- // requests test cases
- requests := []struct {
- method string
- path string
- params map[string]string
- match bool
- }{
- {method: MethodPut, path: "/admin/welcome", match: false, params: make(map[string]string)},
- {method: MethodGet, path: "/articles/search", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/articles/test", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/articles/test", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/articles/test", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/articles/204", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/posts", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/post/502", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/post/a23011a", match: true, params: make(map[string]string)},
- {method: MethodPost, path: "/post/a23011a", match: false, params: make(map[string]string)},
- {method: MethodGet, path: "/user/204", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/user/205", match: true, params: make(map[string]string)},
- {method: MethodPost, path: "/user/204/setting", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/users/ahmed", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/users/ahmed/ahmed", match: true, params: make(map[string]string)},
- {method: MethodPut, path: "/users/ahmed/ahmed", match: false, params: make(map[string]string)},
- {method: MethodGet, path: "/books/get/test", match: true, params: map[string]string{"name": "test"}},
- {method: MethodGet, path: "/books/search/test", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/books/search//test", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/books/search/123", match: true, params: map[string]string{"pattern": "123"}},
- {method: MethodGet, path: "/books/search/test1/test2/test3", match: true, params: map[string]string{"pattern1": "test1", "pattern2": "test2", "pattern3": "test3"}},
- {method: MethodGet, path: "/books/search/test/test2", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/books/search/test/test2", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/account/testuser", match: true, params: map[string]string{"name": "testuser"}},
- {method: MethodGet, path: "/account", match: true, params: make(map[string]string)},
- {method: MethodPut, path: "/account/test1/test2", match: false, params: make(map[string]string)},
- {method: MethodGet, path: "/profile/testuser", match: true, params: map[string]string{"name": "testuser"}},
- {method: MethodGet, path: "/profile", match: true, params: make(map[string]string)},
- {method: MethodGet, path: "/order/test1", match: true, params: map[string]string{"name1": "test1"}},
- {method: MethodGet, path: "/order/test1/test2/", match: true, params: map[string]string{"name1": "test1", "name2": "test2"}},
- {method: MethodPut, path: "/order/test1/test2/test3", match: false, params: make(map[string]string)},
- {method: MethodGet, path: "/", match: true, params: make(map[string]string)},
- }
-
- // test matching routes
- for _, rq := range requests {
- handler, params := gb.matchRoute(rq.method, rq.path)
- if (handler != nil && !rq.match) || (handler == nil && rq.match) {
- t.Errorf("input %s %s find nil expecting handler", rq.method, rq.path)
- }
- for paramKey, expectedParam := range rq.params {
- if actualParam, ok := params[paramKey]; !ok || actualParam != expectedParam {
- if !ok {
- actualParam = "nil"
- }
- for k, w := range params {
- fmt.Println(k, string(w))
- }
-
- t.Errorf("input %s %s parameter %s find %s expecting %s",
- rq.method, rq.path, paramKey, actualParam, expectedParam)
+ {method: MethodGet, path: "/articles/search", conflict: false, handlers: fakeHandlersChain},
+ {method: MethodGet, path: "/articles/test", conflict: false, handlers: fakeHandlersChain},
+ {method: MethodGet, path: "", conflict: true, handlers: fakeHandlersChain},
+ {method: "", path: "/articles/test", conflict: true, handlers: fakeHandlersChain},
+ {method: MethodGet, path: "orders/test", conflict: true, handlers: fakeHandlersChain},
+ {method: MethodGet, path: "/books/test", conflict: true, handlers: emptyHandlersChain},
+ }
+
+ router := &router{
+ settings: &Settings{},
+ cache: make(map[string]*matchResult),
+ pool: sync.Pool{
+ New: func() interface{} {
+ return new(context)
+ },
+ },
+ }
+
+ for _, route := range routes {
+ recv := catchPanic(func() {
+ router.handle(route.method, route.path, route.handlers)
+ })
+
+ if route.conflict {
+ if recv == nil {
+ t.Errorf("no panic for conflicting route '%s'", route.path)
}
+ } else if recv != nil {
+ t.Errorf("unexpected panic for route '%s': %v", route.path, recv)
}
}
-}
-// TestNullRoutingTree tests matching with null routing tree
-func TestNullRoutingTree(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox()
-
- // register route
- gb.registerRoute(MethodGet, "/*", emptyHandlersChain)
-
- // test handler is nil
- if handler, _ := gb.matchRoute(MethodGet, "/hello/world"); handler != nil {
- t.Errorf("input GET /hello/world find handler expecting nil")
- }
}
-// TestMatchAll tests matching all requests with one handler
-func TestMatchAll(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox()
-
- // register route
- gb.registerRoute(MethodGet, "/*", emptyHandlersChain)
- gb.constructRoutingTree()
-
- // test handler is not nil
- if handler, _ := gb.matchRoute(MethodGet, "/hello/world"); handler == nil {
- t.Errorf("input GET /hello/world find nil expecting handler")
- }
-
- if handler, _ := gb.matchRoute(MethodGet, "//world"); handler == nil {
- t.Errorf("input GET //world find nil expecting handler")
- }
-}
-
-// TestConstructRoutingTree tests constructing routing tree with two handlers
-// for the same path and method
-func TestConstructRoutingTreeConflict(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox()
-
- // register routes
- gb.registerRoute(MethodGet, "/articles/test", emptyHandlersChain)
- gb.registerRoute(MethodGet, "/articles/test", emptyHandlersChain)
-
- if err := gb.constructRoutingTree(); err == nil {
- t.Fatalf("invalid listener passed")
- }
-}
-
-// TestNoRegisteredFallback tests that if no registered fallback is available
-// matchRoute() returns nil
-func TestNoRegisteredFallback(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox()
-
- // register routes
- gb.registerRoute(MethodGet, "/articles", emptyHandlersChain)
- gb.constructRoutingTree()
-
- // attempt to match route that cannot match
- if handler, _ := gb.matchRoute(MethodGet, "/fail"); handler != nil {
- t.Errorf("input GET /fail found a valid handler, expecting nil")
- }
-}
-
-// TestFallback tests that if a registered fallback is available
-// matchRoute() returns the non-nil registered fallback handler
-func TestFallback(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox()
-
- // register routes
- gb.registerRoute(MethodGet, "/articles", emptyHandlersChain)
- if err := gb.registerFallback(emptyHandlersChain); err != nil {
- t.Errorf("invalid fallback: %s", err.Error())
- }
- gb.constructRoutingTree()
-
- // attempt to match route that cannot match
- if handler, _ := gb.matchRoute(MethodGet, "/fail"); handler == nil {
- t.Errorf("input GET /fail did not find a valid handler, expecting valid fallback handler")
+func TestHandler(t *testing.T) {
+ routes := []struct {
+ method string
+ path string
+ handlers handlersChain
+ }{
+ {method: MethodGet, path: "/articles/search", handlers: fakeHandlersChain},
+ {method: MethodGet, path: "/articles/test", handlers: fakeHandlersChain},
}
-}
-
-// TestInvalidFallback tests that a fallback cannot be registered
-// with a nil handler
-func TestInvalidFallback(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox()
- // attempt to register an invalid (nil) fallback handler
- if err := gb.registerFallback(nil); err == nil {
- t.Errorf("registering an invalid fallback did not return an error, expected error")
+ router := &router{
+ settings: &Settings{},
+ cache: make(map[string]*matchResult),
+ pool: sync.Pool{
+ New: func() interface{} {
+ return new(context)
+ },
+ },
}
-}
-// TestGroupRouting tests that you can do group routing
-func TestGroupRouting(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox()
- routes := []*Route{gb.Get("/id", emptyHandler), gb.Post("/abc", emptyHandler), gb.Post("/abcd", emptyHandler)}
- gb.Group("/account", routes)
- // attempt to register an invalid (nil) fallback handler
- if err := gb.constructRoutingTree(); err != nil {
- t.Errorf("Grout routing failed, error: %v", err)
+ for _, route := range routes {
+ router.handle(route.method, route.path, route.handlers)
}
-}
-// TestNestedGroupRouting tests that you can do group routing inside a group routing
-func TestNestedGroupRouting(t *testing.T) {
- // create gearbox instance
- gb := setupGearbox()
- routes := []*Route{gb.Get("/id", emptyHandler), gb.Post("/abc", emptyHandler), gb.Post("/abcd", emptyHandler)}
- gb.Group("/account", gb.Group("/api", routes))
- // attempt to register an invalid (nil) fallback handler
- if err := gb.constructRoutingTree(); err != nil {
- t.Errorf("Grout routing failed, error: %v", err)
- }
}
diff --git a/tree.go b/tree.go
new file mode 100644
index 0000000..bf7fba7
--- /dev/null
+++ b/tree.go
@@ -0,0 +1,168 @@
+package gearbox
+
+import (
+ "strings"
+)
+
+type nodeType uint8
+
+const (
+ static nodeType = iota
+ root
+ parama
+ catchAll
+)
+
+type node struct {
+ path string
+ param *node
+ children map[string]*node
+ nType nodeType
+ handlers handlersChain
+}
+
+func (n *node) addRoute(path string, handlers handlersChain) {
+ currentNode := n
+ originalPath := path
+ path = path[1:]
+
+ paramNames := make(map[string]bool)
+
+ for {
+ pathLen := len(path)
+ if pathLen == 0 {
+ if currentNode.handlers != nil {
+ panic("handlers are already registered for path '" + originalPath + "'")
+ }
+ currentNode.handlers = handlers
+ break
+ }
+
+ segmentDelimiter := strings.Index(path, "/")
+ if segmentDelimiter == -1 {
+ segmentDelimiter = pathLen
+ }
+
+ pathSegment := path[:segmentDelimiter]
+ if pathSegment[0] == ':' || pathSegment[0] == '*' {
+ // Parameter
+ if len(currentNode.children) > 0 {
+ panic("parameter " + pathSegment +
+ " conflicts with existing static children in path '" +
+ originalPath + "'")
+ }
+
+ if currentNode.param != nil {
+ if currentNode.param.path[0] == '*' {
+ panic("parameter " + pathSegment +
+ " conflicts with catch all (*) route in path '" +
+ originalPath + "'")
+ } else if currentNode.param.path != pathSegment {
+ panic("parameter " + pathSegment + " in new path '" +
+ originalPath + "' conflicts with existing wildcard '" +
+ currentNode.param.path)
+ }
+ }
+
+ if currentNode.param == nil {
+ var nType nodeType
+ if pathSegment[0] == '*' {
+ nType = catchAll
+ if pathLen > 1 {
+ panic("catch all (*) routes are only allowed " +
+ "at the end of the path in path '" +
+ originalPath + "'")
+ }
+ } else {
+ nType = parama
+ if _, ok := paramNames[pathSegment]; ok {
+ panic("parameter " + pathSegment +
+ " must be unique in path '" + originalPath + "'")
+ } else {
+ paramNames[pathSegment] = true
+ }
+ }
+
+ currentNode.param = &node{
+ path: pathSegment,
+ nType: nType,
+ children: make(map[string]*node),
+ }
+ }
+ currentNode = currentNode.param
+ } else {
+ // Static
+ if currentNode.param != nil {
+ panic(pathSegment + "' conflicts with existing parameter " +
+ currentNode.param.path + " in path '" + originalPath + "'")
+ }
+ if child, ok := currentNode.children[pathSegment]; ok {
+ currentNode = child
+
+ } else {
+ child = &node{
+ path: pathSegment,
+ nType: static,
+ children: make(map[string]*node),
+ }
+ currentNode.children[pathSegment] = child
+ currentNode = child
+ }
+ }
+
+ if pathLen > segmentDelimiter {
+ segmentDelimiter++
+ }
+ path = path[segmentDelimiter:]
+ }
+}
+
+func (n *node) matchRoute(path string, ctx *context) handlersChain {
+ pathLen := len(path)
+ if pathLen > 0 && path[0] != '/' {
+ return nil
+ }
+
+ currentNode := n
+ path = path[1:]
+
+ for {
+ pathLen = len(path)
+
+ if pathLen == 0 || currentNode.nType == catchAll {
+ return currentNode.handlers
+ }
+ segmentDelimiter := strings.Index(path, "/")
+ if segmentDelimiter == -1 {
+ segmentDelimiter = pathLen
+ }
+ pathSegment := path[:segmentDelimiter]
+
+ if pathLen > segmentDelimiter {
+ segmentDelimiter++
+ }
+ path = path[segmentDelimiter:]
+
+ if currentNode.param != nil {
+ currentNode = currentNode.param
+ ctx.paramValues[currentNode.path[1:]] = pathSegment
+ continue
+ }
+
+ if child, ok := currentNode.children[pathSegment]; ok {
+ currentNode = child
+ continue
+ }
+
+ return nil
+ }
+}
+
+// createRootNode creates an instance of node with root type
+func createRootNode() *node {
+ return &node{
+ nType: root,
+ path: "/",
+ children: make(map[string]*node),
+ }
+}
diff --git a/tree_test.go b/tree_test.go
new file mode 100644
index 0000000..5aa6a49
--- /dev/null
+++ b/tree_test.go
@@ -0,0 +1,133 @@
+package gearbox
+
+import "testing"
+
+func catchPanic(f func()) (recv interface{}) {
+ defer func() {
+ recv = recover()
+ }()
+
+ f()
+ return
+}
+
+type testRoute struct {
+ path string
+ conflict bool
+}
+
+func TestAddRoute(t *testing.T) {
+ tree := createRootNode()
+
+ routes := []testRoute{
+ {"/cmd/:tool/:sub", false},
+ {"/cmd/vet", true},
+ {"/src/*", false},
+ {"/src/*", true},
+ {"/src/test", true},
+ {"/src/:test", true},
+ {"/src/", false},
+ {"/src1/", false},
+ {"/src1/*", false},
+ {"/search/:query", false},
+ {"/search/invalid", true},
+ {"/user_:name", false},
+ {"/user_x", false},
+ {"/id:id", false},
+ {"/id/:id", false},
+ {"/id/:value", true},
+ {"/id/:id/settings", false},
+ {"/id/:id/:type", true},
+ {"/*", true},
+ {"books/*/get", true},
+ {"/file/test", false},
+ {"/file/test", true},
+ {"/file/:test", true},
+ {"/orders/:id/settings/:id", true},
+ {"/accounts/*/settings", true},
+ {"/results/*", false},
+ {"/results/*/view", true},
+ }
+ for _, route := range routes {
+ recv := catchPanic(func() {
+ tree.addRoute(route.path, emptyHandlersChain)
+ })
+
+ if route.conflict {
+ if recv == nil {
+ t.Errorf("no panic for conflicting route '%s'", route.path)
+ }
+ } else if recv != nil {
+ t.Errorf("unexpected panic for route '%s': %v", route.path, recv)
+ }
+ }
+}
+
+type testRequests []struct {
+ path string
+ match bool
+ params map[string]string
+}
+
+func TestMatchRoute(t *testing.T) {
+ tree := createRootNode()
+
+ routes := [...]string{
+ "/hi",
+ "/contact",
+ "/users/:id/",
+ "/books/*",
+ "/search/:item1/settings/:item2",
+ "/co",
+ "/c",
+ "/a",
+ "/ab",
+ "/doc/",
+ "/doc/go_faq.html",
+ "/doc/go1.html",
+ "/α",
+ "/β",
+ }
+ for _, route := range routes {
+ tree.addRoute(route, emptyHandlersChain)
+ }
+
+ requests := testRequests{
+ {"/a", true, nil},
+ {"/", false, nil},
+ {"/hi", true, nil},
+ {"/contact", true, nil},
+ {"/co", true, nil},
+ {"/con", false, nil}, // key mismatch
+ {"/cona", false, nil}, // key mismatch
+ {"/no", false, nil}, // no matching child
+ {"/ab", true, nil},
+ {"/α", true, nil},
+ {"/β", true, nil},
+ {"/users/test", true, map[string]string{"id": "test"}},
+ {"/books/title", true, nil},
+ {"/search/test1/settings/test2", true, map[string]string{"item1": "test1", "item2": "test2"}},
+ {"/search/test1", false, nil},
+ {"test", false, nil},
+ }
+ for _, request := range requests {
+ ctx := &context{paramValues: make(map[string]string)}
+ handler := tree.matchRoute(request.path, ctx)
+
+ if handler == nil {
+ if request.match {
+ t.Errorf("handle mismatch for route '%s': Expected non-nil handle", request.path)
+ }
+ } else if !request.match {
+ t.Errorf("handle mismatch for route '%s': Expected nil handle", request.path)
+ }
+
+ for expectedKey, expectedValue := range request.params {
+ actualValue := ctx.Param(expectedKey)
+ if actualValue != expectedValue {
+ t.Errorf(" mismatch for route '%s' parameter '%s' actual '%s', expected '%s'",
+ request.path, expectedKey, actualValue, expectedValue)
+ }
+ }
+ }
+}