package middleware import ( "context" "net/http" "sync" "time" "github.com/labstack/echo/v4" ) // --------------------------------------------------------------------------------------------------------------- // WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING // WARNING: Timeout middleware causes more problems than it solves. // WARNING: This middleware should be first middleware as it messes with request Writer and could cause data race if // it is in other position // // Depending on out requirements you could be better of setting timeout to context and // check its deadline from handler. // // For example: create middleware to set timeout to context // func RequestTimeout(timeout time.Duration) echo.MiddlewareFunc { // return func(next echo.HandlerFunc) echo.HandlerFunc { // return func(c echo.Context) error { // timeoutCtx, cancel := context.WithTimeout(c.Request().Context(), timeout) // c.SetRequest(c.Request().WithContext(timeoutCtx)) // defer cancel() // return next(c) // } // } //} // // Create handler that checks for context deadline and runs actual task in separate coroutine // Note: separate coroutine may not be even if you do not want to process continue executing and // just want to stop long-running handler to stop and you are using "context aware" methods (ala db queries with ctx) // e.GET("/", func(c echo.Context) error { // // doneCh := make(chan error) // go func(ctx context.Context) { // doneCh <- myPossiblyLongRunningBackgroundTaskWithCtx(ctx) // }(c.Request().Context()) // // select { // wait for task to finish or context to timeout/cancelled // case err := <-doneCh: // if err != nil { // return err // } // return c.String(http.StatusOK, "OK") // case <-c.Request().Context().Done(): // if c.Request().Context().Err() == context.DeadlineExceeded { // return c.String(http.StatusServiceUnavailable, "timeout") // } // return c.Request().Context().Err() // } // // }) // // TimeoutConfig defines the config for Timeout middleware. type TimeoutConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code // It can be used to define a custom timeout error message ErrorMessage string // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after // request timeouted and we already had sent the error code (503) and message response to the client. // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` OnTimeoutRouteErrorHandler func(err error, c echo.Context) // Timeout configures a timeout for the middleware, defaults to 0 for no timeout // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output // difference over 500microseconds (0.5millisecond) response seems to be reliable Timeout time.Duration } var ( // DefaultTimeoutConfig is the default Timeout middleware config. DefaultTimeoutConfig = TimeoutConfig{ Skipper: DefaultSkipper, Timeout: 0, ErrorMessage: "", } ) // Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler // call runs for longer than its time limit. NB: timeout does not stop handler execution. func Timeout() echo.MiddlewareFunc { return TimeoutWithConfig(DefaultTimeoutConfig) } // TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration. func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { mw, err := config.ToMiddleware() if err != nil { panic(err) } return mw } // ToMiddleware converts Config to middleware or returns an error for invalid configuration func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultTimeoutConfig.Skipper } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) || config.Timeout == 0 { return next(c) } errChan := make(chan error, 1) handlerWrapper := echoHandlerFuncWrapper{ writer: &ignorableWriter{ResponseWriter: c.Response().Writer}, ctx: c, handler: next, errChan: errChan, errHandler: config.OnTimeoutRouteErrorHandler, } handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage) handler.ServeHTTP(handlerWrapper.writer, c.Request()) select { case err := <-errChan: return err default: return nil } } }, nil } type echoHandlerFuncWrapper struct { writer *ignorableWriter ctx echo.Context handler echo.HandlerFunc errHandler func(err error, c echo.Context) errChan chan error } func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // replace echo.Context Request with the one provided by TimeoutHandler to let later middlewares/handler on the chain // handle properly it's cancellation t.ctx.SetRequest(r) // replace writer with TimeoutHandler custom one. This will guarantee that // `writes by h to its ResponseWriter will return ErrHandlerTimeout.` originalWriter := t.ctx.Response().Writer t.ctx.Response().Writer = rw // in case of panic we restore original writer and call panic again // so it could be handled with global middleware Recover() defer func() { if err := recover(); err != nil { t.ctx.Response().Writer = originalWriter panic(err) } }() err := t.handler(t.ctx) if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded { if err != nil && t.errHandler != nil { t.errHandler(err, t.ctx) } return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers } if err != nil { // This is needed as `http.TimeoutHandler` will write status code by itself on error and after that our tries to write // status code will not work anymore as Echo.Response thinks it has been already "committed" and further writes // create errors in log about `superfluous response.WriteHeader call from` t.writer.Ignore(true) t.ctx.Response().Writer = originalWriter // make sure we restore writer before we signal original coroutine about the error // we pass error from handler to middlewares up in handler chain to act on it if needed. t.errChan <- err return } // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client // and should not anymore send additional headers/data // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body t.ctx.Response().Writer = originalWriter } // ignorableWriter is ResponseWriter implementations that allows us to mark writer to ignore further write calls. This // is handy in cases when you do not have direct control of code being executed (3rd party middleware) but want to make // sure that external code will not be able to write response to the client. // Writer is coroutine safe for writes. type ignorableWriter struct { http.ResponseWriter lock sync.Mutex ignoreWrites bool } func (w *ignorableWriter) Ignore(ignore bool) { w.lock.Lock() w.ignoreWrites = ignore w.lock.Unlock() } func (w *ignorableWriter) WriteHeader(code int) { w.lock.Lock() defer w.lock.Unlock() if w.ignoreWrites { return } w.ResponseWriter.WriteHeader(code) } func (w *ignorableWriter) Write(b []byte) (int, error) { w.lock.Lock() defer w.lock.Unlock() if w.ignoreWrites { return len(b), nil } return w.ResponseWriter.Write(b) }