diff --git a/http/middleware/compress/compress.go b/http/middleware/compress/compress.go index 68521b7a..aa5017b2 100644 --- a/http/middleware/compress/compress.go +++ b/http/middleware/compress/compress.go @@ -20,12 +20,15 @@ type Config struct { Skipper middleware.Skipper // Compression level. - // Optional. Default value -1. + // Optional. Default value 0. Level Level // Length threshold before compression // is used. Optional. Default value 0 MinLength int + + // Schemes is a list of enabled compressiond. Optional. Default [GzipScheme, ZstdScheme] + Schemes []Scheme } type Compression interface { @@ -78,6 +81,7 @@ var DefaultConfig = Config{ Skipper: middleware.DefaultSkipper, Level: DefaultCompression, MinLength: 0, + Schemes: []Scheme{GzipScheme, ZstdScheme}, } // ContentTypesSkipper returns a Skipper based on the list of content types @@ -125,9 +129,41 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { config.MinLength = DefaultConfig.MinLength } - gzipPool := NewGzip(config.Level) - brotliPool := NewBrotli(config.Level) - zstdPool := NewZstd(config.Level) + if len(config.Schemes) == 0 { + config.Schemes = DefaultConfig.Schemes + } + + gzipEnable := false + brotliEnable := false + zstdEnable := false + + for _, s := range config.Schemes { + switch s { + case GzipScheme: + gzipEnable = true + case BrotliScheme: + brotliEnable = true + case ZstdScheme: + zstdEnable = true + } + } + + var gzipPool Compression + var brotliPool Compression + var zstdPool Compression + + if gzipEnable { + gzipPool = NewGzip(config.Level) + } + + if brotliEnable { + brotliPool = NewBrotli(config.Level) + } + + if zstdEnable { + zstdPool = NewZstd(config.Level) + } + bpool := bufferPool() return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -143,13 +179,13 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { var pool Compression var scheme Scheme - if strings.Contains(encodings, ZstdScheme.String()) { + if zstdEnable && strings.Contains(encodings, ZstdScheme.String()) { pool = zstdPool scheme = ZstdScheme - } else if strings.Contains(encodings, BrotliScheme.String()) { + } else if brotliEnable && strings.Contains(encodings, BrotliScheme.String()) { pool = brotliPool scheme = BrotliScheme - } else if strings.Contains(encodings, GzipScheme.String()) { + } else if gzipEnable && strings.Contains(encodings, GzipScheme.String()) { pool = gzipPool scheme = GzipScheme } diff --git a/http/middleware/compress/compress_test.go b/http/middleware/compress/compress_test.go index 2ee10743..3e4c6047 100644 --- a/http/middleware/compress/compress_test.go +++ b/http/middleware/compress/compress_test.go @@ -81,7 +81,7 @@ func TestCompress(t *testing.T) { c := e.NewContext(req, rec) // Skip if no Accept-Encoding header - h := New()(func(c echo.Context) error { + h := NewWithConfig(Config{Schemes: []Scheme{scheme}})(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) @@ -113,7 +113,7 @@ func TestCompress(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec) - New()(func(c echo.Context) error { + NewWithConfig(Config{Schemes: []Scheme{scheme}})(func(c echo.Context) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") @@ -150,7 +150,7 @@ func TestCompressWithMinLength(t *testing.T) { t.Run(scheme.String(), func(t *testing.T) { e := echo.New() // Invalid level - e.Use(NewWithConfig(Config{MinLength: 5})) + e.Use(NewWithConfig(Config{MinLength: 5, Schemes: []Scheme{scheme}})) e.GET("/", func(c echo.Context) error { c.Response().Write([]byte("test")) return nil @@ -192,7 +192,7 @@ func TestCompressNoContent(t *testing.T) { req.Header.Set(echo.HeaderAcceptEncoding, scheme.String()) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := New()(func(c echo.Context) error { + h := NewWithConfig(Config{Schemes: []Scheme{scheme}})(func(c echo.Context) error { return c.NoContent(http.StatusNoContent) }) if assert.NoError(t, h(c)) { @@ -214,7 +214,7 @@ func TestCompressEmpty(t *testing.T) { req.Header.Set(echo.HeaderAcceptEncoding, scheme.String()) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := New()(func(c echo.Context) error { + h := NewWithConfig(Config{Schemes: []Scheme{scheme}})(func(c echo.Context) error { return c.String(http.StatusOK, "") }) if assert.NoError(t, h(c)) { @@ -237,7 +237,7 @@ func TestCompressErrorReturned(t *testing.T) { for scheme := range schemes { t.Run(scheme.String(), func(t *testing.T) { e := echo.New() - e.Use(New()) + e.Use(NewWithConfig(Config{Schemes: []Scheme{scheme}})) e.GET("/", func(c echo.Context) error { return echo.ErrNotFound }) @@ -258,7 +258,7 @@ func TestCompressWithStatic(t *testing.T) { for scheme, reader := range schemes { t.Run(scheme.String(), func(t *testing.T) { e := echo.New() - e.Use(New()) + e.Use(NewWithConfig(Config{Schemes: []Scheme{scheme}})) e.Static("/test", "./") req := httptest.NewRequest(http.MethodGet, "/test/compress.go", nil) req.Header.Set(echo.HeaderAcceptEncoding, scheme.String()) @@ -294,8 +294,8 @@ func BenchmarkCompress(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, scheme.String()) - h := New()(func(c echo.Context) error { - c.Response().Write([]byte("test")) + h := NewWithConfig(Config{Level: BestSpeed, Schemes: []Scheme{scheme}})(func(c echo.Context) error { + c.Response().Write([]byte("testtesttesttesttesttesttesttesttesttesttesttesttest")) return nil }) @@ -324,7 +324,7 @@ func BenchmarkCompressLarge(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, scheme.String()) - h := New()(func(c echo.Context) error { + h := NewWithConfig(Config{Level: BestSpeed, Schemes: []Scheme{scheme}})(func(c echo.Context) error { c.Response().Write(data) return nil })