中间件的作用是在不影响原有应用编码方式的前提下,为应用添加一些额外的功能:想用的时候直接添加,不想用的时候可以很轻松地去除,做到所谓的"可插拔"。
中间件的实现位置在哪里?
位置:在处理器的前后
注意:中间件是一个调用链条,所以在处理真正的业务之前,可能会经过多个中间件
// MiddlewareFunc takes a HandleFunc and returns a HandleFunc, which lets a
// middleware wrap the handler it is given with extra behavior.
type MiddlewareFunc func(handleFunc HandleFunc)HandleFunc
中间件先定义为组级别:注册到某个路由组之后,该组下所有路由的处理器在执行时都会经过这些中间件。
// routerGroup is a named group of routes together with the middlewares that
// apply to every route in the group.
type routerGroup struct {
// group name
name string
// handlers, indexed first by route name and then by HTTP method (or ANY)
handleFuncMap map[string]map[string]HandleFunc
// NOTE(review): not referenced in the visible code — confirm intended use
handlerMethodMap map[string][]string
// route tree consulted when matching an incoming request path
treeNode *treeNode
preMiddlewares []MiddlewareFunc // pre-handler middlewares
postMiddlewares []MiddlewareFunc // post-handler middlewares
}
// PreHandle appends the given middlewares, in the order supplied, to the
// group's list of pre-handler middlewares.
func (r *routerGroup) PreHandle(middlewareFunc ...MiddlewareFunc) {
	if len(middlewareFunc) == 0 {
		// Nothing to register; the list stays exactly as it was.
		return
	}
	r.preMiddlewares = append(r.preMiddlewares, middlewareFunc...)
}
// methodHandle runs h for the request described by ctx, surrounded by the
// group's middlewares.
//
// Pre-middlewares are wrapped around h one after another, so the middleware
// registered last ends up outermost and runs first; the resulting chain is
// invoked exactly once. Post-middlewares are then wrapped around a no-op
// handler and invoked, so they run after h without executing h again.
func (r *routerGroup) methodHandle(h HandleFunc, ctx *Context) {
	// Pre-middlewares. Ranging over a nil slice is a no-op, so no nil check
	// is needed.
	for _, mw := range r.preMiddlewares {
		h = mw(h)
	}
	h(ctx)

	// Post-middlewares. These must come from postMiddlewares (not
	// preMiddlewares) and the wrapped chain must actually be called,
	// otherwise anything registered as a post-middleware never runs.
	post := HandleFunc(func(*Context) {})
	for _, mw := range r.postMiddlewares {
		post = mw(post)
	}
	post(ctx)
}
// ServeHTTP hands every incoming request to httpRquestHandler. Its signature
// matches the ServeHTTP method of net/http's Handler interface.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
e.httpRquestHandler(w, r)
}
// httpRquestHandler dispatches a request to the first router group whose
// route tree contains a terminal node for the request URI.
//
// For a matched route, a handler registered under ANY takes precedence over
// one registered for the request's own method. The chosen handler is run
// exactly once, through the group's middleware chain. If the route matches
// but no handler exists for the method, the response is 405; if no group
// matches at all, the response is 404.
func (e *Engine) httpRquestHandler(w http.ResponseWriter, r *http.Request) {
	method := r.Method
	for _, group := range e.routerGroups {
		// presumably the part of the URI after "/<group name>", e.g. "get/1" —
		// confirm against SubStringLast.
		routerName := SubStringLast(r.RequestURI, "/"+group.name)
		node := group.treeNode.Get(routerName)
		if node == nil || !node.isEnd {
			continue
		}

		// The route matched.
		ctx := &Context{
			W: w,
			R: r,
		}
		handlers := group.handleFuncMap[node.routerName]
		handle, ok := handlers[ANY]
		if !ok {
			handle, ok = handlers[method]
		}
		if ok {
			// methodHandle already invokes the handler; calling handle(ctx)
			// again here would run it a second time without middlewares.
			group.methodHandle(handle, ctx)
			return
		}

		// The route exists, but not for this method.
		w.WriteHeader(http.StatusMethodNotAllowed)
		fmt.Fprintf(w, "%s %s not allow\n", r.RequestURI, method)
		return
	}
	w.WriteHeader(http.StatusNotFound)
	fmt.Fprintf(w, "%s not found\n", r.RequestURI)
}
// PostHandle appends the given middlewares, in the order supplied, to the
// group's list of post-handler middlewares.
func (r *routerGroup) PostHandle(middlewareFunc ...MiddlewareFunc) {
	if len(middlewareFunc) == 0 {
		// Nothing to register; the list stays exactly as it was.
		return
	}
	r.postMiddlewares = append(r.postMiddlewares, middlewareFunc...)
}
// methodHandle runs h for the request described by ctx, surrounded by the
// group's middlewares.
//
// Pre-middlewares are wrapped around h one after another, so the middleware
// registered last ends up outermost and runs first; the resulting chain is
// invoked exactly once. Post-middlewares are then wrapped around a no-op
// handler and invoked afterwards.
func (r *routerGroup) methodHandle(h HandleFunc, ctx *Context) {
	// Pre-middlewares. Ranging over a nil slice is a no-op, so no nil check
	// is needed.
	for _, mw := range r.preMiddlewares {
		h = mw(h)
	}
	h(ctx)

	// Post-middlewares. They wrap a no-op rather than h: h has already run,
	// and wrapping it again would re-execute the pre-middlewares and the
	// handler whenever a post-middleware calls the function it was given.
	post := HandleFunc(func(*Context) {})
	for _, mw := range r.postMiddlewares {
		post = mw(post)
	}
	post(ctx)
}
// main builds a demo "user" group with one pre-middleware, one
// post-middleware and four routes, then starts the engine.
func main() {
	engine := msgo.New()
	g := engine.Group("user")

	// Disabled example route, kept for reference:
	//g.Get("/hello", func(ctx *msgo.Context) {
	//	fmt.Fprintf(ctx.W, "%s Get欢迎学习GO自研框架", "lisus2000")
	//})

	const user = "lisus2000"

	// reply builds a handler that writes format, filled in with user, to the
	// response writer.
	reply := func(format string) msgo.HandleFunc {
		return func(ctx *msgo.Context) {
			fmt.Fprintf(ctx.W, format, user)
		}
	}

	// before logs a line and then passes control on to the wrapped handler.
	before := func(next msgo.HandleFunc) msgo.HandleFunc {
		return func(ctx *msgo.Context) {
			fmt.Println("pre Handle")
			next(ctx)
		}
	}
	// after only logs a line; it never calls the handler it is given.
	after := func(_ msgo.HandleFunc) msgo.HandleFunc {
		return func(_ *msgo.Context) {
			fmt.Println("post handle")
		}
	}

	g.PreHandle(before)
	g.PostHandle(after)

	g.Get("/hello/get", reply("%s Get欢迎学习GO自研框架"))
	g.Post("/hello", reply("%s Post欢迎学习GO自研框架"))
	g.Post("/info", reply("%s info功能"))
	g.Get("/get/:id", reply("%s get user 路径变量的值"))

	engine.Run()
}
中间件是通过包装用户处理函数来执行的:在中间件内部,调用 next(ctx) 之前的代码会在处理函数之前执行,调用 next(ctx) 之后的代码会在处理函数执行完毕后执行。也就是说,前置中间件本身就可以完成后置中间件的功能,所以这里我们去除后置中间件,只保留一种中间件。
// routerGroup defines the route-group structure: a named set of routes plus
// the middlewares shared by every route in the group.
type routerGroup struct {
// group name
name string
// handlers, indexed first by route name and then by HTTP method (or ANY)
handleFuncMap map[string]map[string]HandleFunc
// NOTE(review): not referenced in the visible code — confirm intended use
handlerMethodMap map[string][]string
// route tree consulted when matching an incoming request path
treeNode *treeNode
// group-level middlewares registered through Use
middlewares []MiddlewareFunc
}
// Use appends the given middlewares, in the order supplied, to the group's
// middleware list.
func (r *routerGroup) Use(middlewareFunc ...MiddlewareFunc) {
	if len(middlewareFunc) == 0 {
		// Nothing to register; the list stays exactly as it was.
		return
	}
	r.middlewares = append(r.middlewares, middlewareFunc...)
}
// methodHandle wraps h with every group middleware and invokes the result
// once with ctx. Middlewares are applied in slice order, so the one
// registered last is the outermost wrapper and therefore runs first.
func (r *routerGroup) methodHandle(h HandleFunc, ctx *Context) {
	chain := h
	for i := 0; i < len(r.middlewares); i++ {
		chain = r.middlewares[i](chain)
	}
	chain(ctx)
}
测试代码
// main builds a demo "user" group with a single middleware that logs before
// and after the handler, registers four routes, and starts the engine.
func main() {
	engine := msgo.New()
	g := engine.Group("user")

	// Disabled example route, kept for reference:
	//g.Get("/hello", func(ctx *msgo.Context) {
	//	fmt.Fprintf(ctx.W, "%s Get欢迎学习GO自研框架", "lisus2000")
	//})

	const user = "lisus2000"

	// reply builds a handler that writes format, filled in with user, to the
	// response writer.
	reply := func(format string) msgo.HandleFunc {
		return func(ctx *msgo.Context) {
			fmt.Fprintf(ctx.W, format, user)
		}
	}

	// around logs one line before and one line after the wrapped handler.
	around := func(next msgo.HandleFunc) msgo.HandleFunc {
		return func(ctx *msgo.Context) {
			fmt.Println("pre Handle")
			next(ctx)
			fmt.Println("post Handler")
		}
	}
	g.Use(around)

	g.Get("/hello/get", reply("%s Get欢迎学习GO自研框架"))
	g.Post("/hello", reply("%s Post欢迎学习GO自研框架"))
	g.Post("/info", reply("%s info功能"))
	g.Get("/get/:id", reply("%s get user 路径变量的值"))

	engine.Run()
}
// methodHandle wraps h first with the group-wide middlewares and then with
// the middlewares registered for the route (name, mehtod), and invokes the
// resulting chain once with ctx.
//
// Each wrap adds an outer layer, so route-level middlewares sit outside the
// group-level ones, and within each list the last entry is outermost.
func (r *routerGroup) methodHandle(name string, mehtod string, h HandleFunc, ctx *Context) {
	chain := h

	// Group-wide middlewares.
	for i := 0; i < len(r.middlewares); i++ {
		chain = r.middlewares[i](chain)
	}

	// Route-level middlewares; a missing entry yields a nil slice, which the
	// loop simply skips.
	routeLevel := r.middlewaresFuncMap[name][mehtod]
	for i := 0; i < len(routeLevel); i++ {
		chain = routeLevel[i](chain)
	}

	chain(ctx)
}
// handle registers handleFunc, plus any route-level middlewares, for the
// given route name and method, and adds the route name to the group's tree.
// It panics if a handler is already registered for the same name and method.
func (r *routerGroup) handle(name string, method string, handleFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
	byMethod, known := r.handleFuncMap[name]
	if !known {
		// First registration for this route name: create both inner maps.
		byMethod = make(map[string]HandleFunc)
		r.handleFuncMap[name] = byMethod
		r.middlewaresFuncMap[name] = make(map[string][]MiddlewareFunc)
	}
	if _, dup := byMethod[method]; dup {
		panic("有重复的路由")
	}
	byMethod[method] = handleFunc

	routeMiddlewares := r.middlewaresFuncMap[name]
	routeMiddlewares[method] = append(routeMiddlewares[method], middlewareFunc...)

	r.treeNode.Put(name)
}
// Any registers handleFunc for name under the ANY method key, together with
// any route-level middlewares, by delegating to handle.
func (r *routerGroup) Any(name string, handleFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, ANY, handleFunc, middlewareFunc...)
}
// Get registers handleFunc for name under http.MethodGet, together with any
// route-level middlewares, by delegating to handle.
func (r *routerGroup) Get(name string, handleFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodGet, handleFunc, middlewareFunc...)
}
// Post registers handleFunc for name under http.MethodPost, together with any
// route-level middlewares, by delegating to handle.
func (r *routerGroup) Post(name string, handleFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodPost, handleFunc, middlewareFunc...)
}
// Delete registers handlerFunc for name under http.MethodDelete, together
// with any route-level middlewares, by delegating to handle.
func (r *routerGroup) Delete(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodDelete, handlerFunc, middlewareFunc...)
}
// Put registers handlerFunc for name under http.MethodPut, together with any
// route-level middlewares, by delegating to handle.
func (r *routerGroup) Put(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodPut, handlerFunc, middlewareFunc...)
}
// Patch registers handlerFunc for name under http.MethodPatch, together with
// any route-level middlewares, by delegating to handle.
func (r *routerGroup) Patch(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodPatch, handlerFunc, middlewareFunc...)
}
// Options registers handlerFunc for name under http.MethodOptions, together
// with any route-level middlewares, by delegating to handle.
func (r *routerGroup) Options(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodOptions, handlerFunc, middlewareFunc...)
}
// Head registers handlerFunc for name under http.MethodHead, together with
// any route-level middlewares, by delegating to handle.
func (r *routerGroup) Head(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodHead, handlerFunc, middlewareFunc...)
}
完整代码如下:
package msgo
import (
"fmt"
"log"
"net/http"
)
// ANY is the method key under which Any stores its handlers; the request
// dispatcher looks it up before the request's own HTTP method.
const ANY = "ANY"
// HandleFunc is the signature of a request handler; it receives the
// per-request Context.
type HandleFunc func(ctx *Context)
// MiddlewareFunc takes a HandleFunc and returns a HandleFunc, which lets a
// middleware wrap the handler it is given with extra behavior.
type MiddlewareFunc func(handleFunc HandleFunc) HandleFunc
// routerGroup defines the route-group structure: a named set of routes plus
// the group-level and route-level middlewares that apply to them.
type routerGroup struct {
// group name
name string
// handlers, indexed first by route name and then by HTTP method (or ANY)
handleFuncMap map[string]map[string]HandleFunc
// route-level middlewares, indexed the same way as handleFuncMap
middlewaresFuncMap map[string]map[string][]MiddlewareFunc
// NOTE(review): allocated in Group but not otherwise referenced in the
// visible code — confirm intended use
handlerMethodMap map[string][]string
// route tree: handle adds route names to it, the dispatcher queries it
treeNode *treeNode
// group-level middlewares registered through Use
middlewares []MiddlewareFunc
}
// router holds the route groups created through Group.
type router struct {
// groups in creation order; the dispatcher scans them in this order
routerGroups []*routerGroup
}
// Group creates a new route group called name, with empty handler and
// middleware maps and a route tree rooted at "/", records it on the router,
// and returns it.
func (r *router) Group(name string) *routerGroup {
	root := &treeNode{name: "/", children: make([]*treeNode, 0)}

	g := &routerGroup{name: name, treeNode: root}
	g.handleFuncMap = make(map[string]map[string]HandleFunc)
	g.middlewaresFuncMap = make(map[string]map[string][]MiddlewareFunc)
	g.handlerMethodMap = make(map[string][]string)

	r.routerGroups = append(r.routerGroups, g)
	return g
}
// Use appends the given middlewares, in the order supplied, to the group's
// middleware list.
func (r *routerGroup) Use(middlewareFunc ...MiddlewareFunc) {
	if len(middlewareFunc) == 0 {
		// Nothing to register; the list stays exactly as it was.
		return
	}
	r.middlewares = append(r.middlewares, middlewareFunc...)
}
// methodHandle wraps h first with the group-wide middlewares and then with
// the middlewares registered for the route (name, mehtod), and invokes the
// resulting chain once with ctx.
//
// Each wrap adds an outer layer, so route-level middlewares sit outside the
// group-level ones, and within each list the last entry is outermost.
func (r *routerGroup) methodHandle(name string, mehtod string, h HandleFunc, ctx *Context) {
	chain := h

	// Group-wide middlewares.
	for i := 0; i < len(r.middlewares); i++ {
		chain = r.middlewares[i](chain)
	}

	// Route-level middlewares; a missing entry yields a nil slice, which the
	// loop simply skips.
	routeLevel := r.middlewaresFuncMap[name][mehtod]
	for i := 0; i < len(routeLevel); i++ {
		chain = routeLevel[i](chain)
	}

	chain(ctx)
}
// handle registers handleFunc, plus any route-level middlewares, for the
// given route name and method, and adds the route name to the group's tree.
// It panics if a handler is already registered for the same name and method.
func (r *routerGroup) handle(name string, method string, handleFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
	byMethod, known := r.handleFuncMap[name]
	if !known {
		// First registration for this route name: create both inner maps.
		byMethod = make(map[string]HandleFunc)
		r.handleFuncMap[name] = byMethod
		r.middlewaresFuncMap[name] = make(map[string][]MiddlewareFunc)
	}
	if _, dup := byMethod[method]; dup {
		panic("有重复的路由")
	}
	byMethod[method] = handleFunc

	routeMiddlewares := r.middlewaresFuncMap[name]
	routeMiddlewares[method] = append(routeMiddlewares[method], middlewareFunc...)

	r.treeNode.Put(name)
}
// Any registers handleFunc for name under the ANY method key, together with
// any route-level middlewares, by delegating to handle.
func (r *routerGroup) Any(name string, handleFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, ANY, handleFunc, middlewareFunc...)
}
// Get registers handleFunc for name under http.MethodGet, together with any
// route-level middlewares, by delegating to handle.
func (r *routerGroup) Get(name string, handleFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodGet, handleFunc, middlewareFunc...)
}
// Post registers handleFunc for name under http.MethodPost, together with any
// route-level middlewares, by delegating to handle.
func (r *routerGroup) Post(name string, handleFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodPost, handleFunc, middlewareFunc...)
}
// Delete registers handlerFunc for name under http.MethodDelete, together
// with any route-level middlewares, by delegating to handle.
func (r *routerGroup) Delete(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodDelete, handlerFunc, middlewareFunc...)
}
// Put registers handlerFunc for name under http.MethodPut, together with any
// route-level middlewares, by delegating to handle.
func (r *routerGroup) Put(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodPut, handlerFunc, middlewareFunc...)
}
// Patch registers handlerFunc for name under http.MethodPatch, together with
// any route-level middlewares, by delegating to handle.
func (r *routerGroup) Patch(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodPatch, handlerFunc, middlewareFunc...)
}
// Options registers handlerFunc for name under http.MethodOptions, together
// with any route-level middlewares, by delegating to handle.
func (r *routerGroup) Options(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodOptions, handlerFunc, middlewareFunc...)
}
// Head registers handlerFunc for name under http.MethodHead, together with
// any route-level middlewares, by delegating to handle.
func (r *routerGroup) Head(name string, handlerFunc HandleFunc, middlewareFunc ...MiddlewareFunc) {
r.handle(name, http.MethodHead, handlerFunc, middlewareFunc...)
}
// Engine is the framework's entry point. It embeds router, so router's
// methods (such as Group) are promoted onto Engine.
type Engine struct {
router
}
// New returns a fresh Engine whose embedded router has no groups yet.
func New() *Engine {
	// The zero value of router is an empty router, so a zero Engine suffices.
	var e Engine
	return &e
}
// ServeHTTP hands every incoming request to httpRquestHandler. Its signature
// matches the ServeHTTP method of net/http's Handler interface, which is what
// allows Run to pass the Engine to http.Handle.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
e.httpRquestHandler(w, r)
}
// Run registers the Engine as the handler for "/" on net/http's default mux
// and serves on port 8111. If the server returns an error, it is logged via
// log.Fatal, which terminates the process.
func (e *Engine) Run() {
	http.Handle("/", e)
	if err := http.ListenAndServe(":8111", nil); err != nil {
		log.Fatal(err)
	}
}
// httpRquestHandler dispatches a request to the first router group whose
// route tree contains a terminal node for the request URI.
//
// For a matched route, a handler registered under ANY takes precedence over
// one registered for the request's own method. The chosen handler is run
// exactly once, through the group-level and route-level middleware chain.
// If the route matches but no handler exists for the method, the response is
// 405; if no group matches at all, the response is 404.
func (e *Engine) httpRquestHandler(w http.ResponseWriter, r *http.Request) {
	method := r.Method
	for _, group := range e.routerGroups {
		// presumably the part of the URI after "/<group name>", e.g. "get/1" —
		// confirm against SubStringLast.
		routerName := SubStringLast(r.RequestURI, "/"+group.name)
		node := group.treeNode.Get(routerName)
		if node == nil || !node.isEnd {
			continue
		}

		// The route matched.
		ctx := &Context{
			W: w,
			R: r,
		}
		handlers := group.handleFuncMap[node.routerName]

		// key is the method key the handler was registered under; it is also
		// what methodHandle uses to find the route-level middlewares.
		key := ANY
		handle, ok := handlers[key]
		if !ok {
			key = method
			handle, ok = handlers[key]
		}
		if ok {
			// methodHandle already invokes the handler; calling handle(ctx)
			// again here would run it a second time without middlewares.
			group.methodHandle(node.routerName, key, handle, ctx)
			return
		}

		// The route exists, but not for this method.
		w.WriteHeader(http.StatusMethodNotAllowed)
		fmt.Fprintf(w, "%s %s not allow\n", r.RequestURI, method)
		return
	}
	w.WriteHeader(http.StatusNotFound)
	fmt.Fprintf(w, "%s not found\n", r.RequestURI)
}
路由级别测试代码
package main
import (
"fmt"
"msgo"
)
// Log is a middleware that prints one line before calling next and another
// line after next returns.
func Log(next msgo.HandleFunc) msgo.HandleFunc {
	logged := func(ctx *msgo.Context) {
		fmt.Println("打印请求参数")
		next(ctx)
		fmt.Println("返回执行")
	}
	return logged
}
// main builds a demo "user" group with one group-level middleware, registers
// four routes (the first with the route-level Log middleware), and starts
// the engine.
func main() {
	engine := msgo.New()
	g := engine.Group("user")

	// Disabled example route, kept for reference:
	//g.Get("/hello", func(ctx *msgo.Context) {
	//	fmt.Fprintf(ctx.W, "%s Get欢迎学习GO自研框架", "lisus2000")
	//})

	const user = "lisus2000"

	// reply builds a handler that writes format, filled in with user, to the
	// response writer.
	reply := func(format string) msgo.HandleFunc {
		return func(ctx *msgo.Context) {
			fmt.Fprintf(ctx.W, format, user)
		}
	}

	// around logs one line before and one line after the wrapped handler.
	around := func(next msgo.HandleFunc) msgo.HandleFunc {
		return func(ctx *msgo.Context) {
			fmt.Println("pre Handle")
			next(ctx)
			fmt.Println("post Handler")
		}
	}
	g.Use(around)

	// Only this route carries the extra route-level middleware.
	g.Get("/hello/get", reply("%s Get欢迎学习GO自研框架"), Log)
	g.Post("/hello", reply("%s Post欢迎学习GO自研框架"))
	g.Post("/info", reply("%s info功能"))
	g.Get("/get/:id", reply("%s get user 路径变量的值"))

	engine.Run()
}