Fix basic auth behaviour
This commit is contained in:
parent
8215c20ae6
commit
9c2691e8d0
@ -83,6 +83,14 @@ func NewWithConfig(config Config) echo.MiddlewareFunc {
|
||||
config.Skipper = DefaultConfig.Skipper
|
||||
}
|
||||
|
||||
if len(config.Mounts) == 0 {
|
||||
config.Mounts = append(config.Mounts, "/")
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = log.New("")
|
||||
}
|
||||
|
||||
mw := iammiddleware{
|
||||
iam: config.IAM,
|
||||
mounts: config.Mounts,
|
||||
@ -168,6 +176,8 @@ func NewWithConfig(config Config) echo.MiddlewareFunc {
|
||||
if err == ErrUnauthorized {
|
||||
c.Response().Header().Set(echo.HeaderWWWAuthenticate, "Basic realm=datarhei-core")
|
||||
return api.Err(http.StatusUnauthorized, "Unauthorized", "%s", err)
|
||||
} else if err == ErrBadRequest {
|
||||
return api.Err(http.StatusBadRequest, "Bad request", "%s", err)
|
||||
}
|
||||
|
||||
return api.Err(http.StatusForbidden, "Forbidden", "%s", err)
|
||||
@ -199,6 +209,7 @@ func NewWithConfig(config Config) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
var ErrUnauthorized = errors.New("unauthorized")
|
||||
var ErrBadRequest = errors.New("bad request")
|
||||
|
||||
func (m *iammiddleware) findIdentityFromBasicAuth(c echo.Context) (iam.IdentityVerifier, error) {
|
||||
basic := "basic"
|
||||
@ -206,20 +217,14 @@ func (m *iammiddleware) findIdentityFromBasicAuth(c echo.Context) (iam.IdentityV
|
||||
l := len(basic)
|
||||
|
||||
if len(auth) == 0 {
|
||||
method := c.Request().Method
|
||||
if method == http.MethodGet || method == http.MethodHead || method == http.MethodOptions {
|
||||
return nil, nil
|
||||
path := c.Request().URL.Path
|
||||
domain := m.findDomainFromFilesystem(path)
|
||||
if len(domain) == 0 {
|
||||
domain = "$none"
|
||||
}
|
||||
|
||||
path := c.Request().URL.Path
|
||||
for _, m := range m.mounts {
|
||||
if m == "/" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(path, m+"/") {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
if !m.iam.Enforce("$anon", domain, "fs:"+path, c.Request().Method) {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
@ -233,7 +238,7 @@ func (m *iammiddleware) findIdentityFromBasicAuth(c echo.Context) (iam.IdentityV
|
||||
// instead should be treated as invalid client input
|
||||
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, ErrBadRequest
|
||||
}
|
||||
|
||||
cred := string(b)
|
||||
@ -251,7 +256,7 @@ func (m *iammiddleware) findIdentityFromBasicAuth(c echo.Context) (iam.IdentityV
|
||||
"path": c.Request().URL.Path,
|
||||
"method": c.Request().Method,
|
||||
}).WithError(err).Log("identity not found")
|
||||
return nil, fmt.Errorf("invalid username or password")
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
if ok, err := identity.VerifyServiceBasicAuth(password); !ok {
|
||||
@ -259,7 +264,7 @@ func (m *iammiddleware) findIdentityFromBasicAuth(c echo.Context) (iam.IdentityV
|
||||
"path": c.Request().URL.Path,
|
||||
"method": c.Request().Method,
|
||||
}).WithError(err).Log("wrong password")
|
||||
return nil, fmt.Errorf("invalid username or password")
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
return identity, nil
|
||||
@ -424,7 +429,11 @@ func (m *iammiddleware) findDomainFromFilesystem(path string) string {
|
||||
// Remove it from the path and split it into components: foobar file.txt
|
||||
// Check if foobar a known domain. If yes, return it. If not, return empty domain.
|
||||
for _, mount := range m.mounts {
|
||||
prefix := filepath.Clean(mount) + "/"
|
||||
prefix := filepath.Clean(mount)
|
||||
if prefix != "/" {
|
||||
prefix += "/"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
elements := strings.Split(strings.TrimPrefix(path, prefix), "/")
|
||||
if m.iam.HasDomain(elements[0]) {
|
||||
|
||||
372
http/middleware/iam/iam_test.go
Normal file
372
http/middleware/iam/iam_test.go
Normal file
@ -0,0 +1,372 @@
|
||||
package iam
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/datarhei/core/v16/http/api"
|
||||
apihandler "github.com/datarhei/core/v16/http/handler/api"
|
||||
"github.com/datarhei/core/v16/http/validator"
|
||||
"github.com/datarhei/core/v16/iam"
|
||||
"github.com/datarhei/core/v16/io/fs"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var basic string = "Basic"
|
||||
|
||||
func getIAM() (iam.IAM, error) {
|
||||
dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i, err := iam.NewIAM(iam.Config{
|
||||
FS: dummyfs,
|
||||
Superuser: iam.User{
|
||||
Name: "admin",
|
||||
},
|
||||
JWTRealm: "datarhei-core",
|
||||
JWTSecret: "1234567890",
|
||||
Logger: nil,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i.CreateIdentity(iam.User{
|
||||
Name: "foobar",
|
||||
Auth: iam.UserAuth{
|
||||
API: iam.UserAuthAPI{
|
||||
Userpass: iam.UserAuthPassword{
|
||||
Enable: true,
|
||||
Password: "secret",
|
||||
},
|
||||
},
|
||||
Services: iam.UserAuthServices{
|
||||
Basic: iam.UserAuthPassword{
|
||||
Enable: true,
|
||||
Password: "secret",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func TestNoIAM(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
h := New()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
err := h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
he := err.(api.Error)
|
||||
require.Equal(t, http.StatusForbidden, he.Code)
|
||||
}
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
iam, err := getIAM()
|
||||
require.NoError(t, err)
|
||||
|
||||
iam.AddPolicy("foobar", "$none", "fs:/**", "ANY")
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
h := NewWithConfig(Config{
|
||||
IAM: iam,
|
||||
})(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// No credentials
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
he := err.(api.Error)
|
||||
require.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
require.Equal(t, basic+` realm=datarhei-core`, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
|
||||
// Valid credentials
|
||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("foobar:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
require.NoError(t, h(c))
|
||||
|
||||
// Case-insensitive header scheme
|
||||
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("foobar:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
require.NoError(t, h(c))
|
||||
|
||||
// Invalid credentials
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("foobar:invalid-password"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(api.Error)
|
||||
require.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
require.Equal(t, basic+` realm=datarhei-core`, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
|
||||
// Invalid base64 string
|
||||
auth = basic + " invalidString"
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(api.Error)
|
||||
require.Equal(t, http.StatusBadRequest, he.Code)
|
||||
|
||||
// Missing Authorization header
|
||||
req.Header.Del(echo.HeaderAuthorization)
|
||||
he = h(c).(api.Error)
|
||||
require.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
|
||||
// Invalid Authorization header
|
||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(api.Error)
|
||||
require.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
}
|
||||
|
||||
func TestFindDomainFromFilesystem(t *testing.T) {
|
||||
iam, err := getIAM()
|
||||
require.NoError(t, err)
|
||||
|
||||
iam.AddPolicy("$anon", "$none", "fs:/**", "ANY")
|
||||
iam.AddPolicy("foobar", "group", "fs:/group/**", "ANY")
|
||||
iam.AddPolicy("foobar", "anothergroup", "fs:/memfs/anothergroup/**", "ANY")
|
||||
|
||||
mw := &iammiddleware{
|
||||
iam: iam,
|
||||
mounts: []string{"/memfs", "/"},
|
||||
}
|
||||
|
||||
domain := mw.findDomainFromFilesystem("/")
|
||||
require.Equal(t, "", domain)
|
||||
|
||||
domain = mw.findDomainFromFilesystem("/group/bla")
|
||||
require.Equal(t, "group", domain)
|
||||
|
||||
domain = mw.findDomainFromFilesystem("/anothergroup/bla")
|
||||
require.Equal(t, "anothergroup", domain)
|
||||
|
||||
domain = mw.findDomainFromFilesystem("/memfs/anothergroup/bla")
|
||||
require.Equal(t, "anothergroup", domain)
|
||||
}
|
||||
|
||||
func TestBasicAuthDomain(t *testing.T) {
|
||||
iam, err := getIAM()
|
||||
require.NoError(t, err)
|
||||
|
||||
iam.AddPolicy("$anon", "$none", "fs:/**", "ANY")
|
||||
iam.AddPolicy("foobar", "group", "fs:/group/**", "ANY")
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
h := NewWithConfig(Config{
|
||||
IAM: iam,
|
||||
Mounts: []string{"/"},
|
||||
})(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// No credentials
|
||||
require.NoError(t, h(c))
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/group/bla", nil)
|
||||
c = e.NewContext(req, res)
|
||||
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
he := err.(api.Error)
|
||||
require.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
require.Equal(t, basic+` realm=datarhei-core`, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
|
||||
// Valid credentials
|
||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("foobar:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
require.NoError(t, h(c))
|
||||
|
||||
// Allow anonymous group read access
|
||||
iam.AddPolicy("$anon", "group", "fs:/group/**", "GET")
|
||||
|
||||
req.Header.Del(echo.HeaderAuthorization)
|
||||
require.NoError(t, h(c))
|
||||
}
|
||||
|
||||
func TestAPILoginAndRefresh(t *testing.T) {
|
||||
iam, err := getIAM()
|
||||
require.NoError(t, err)
|
||||
|
||||
iam.AddPolicy("foobar", "$none", "api:/**", "ANY")
|
||||
|
||||
jwthandler := apihandler.NewJWT(iam)
|
||||
|
||||
e := echo.New()
|
||||
e.Validator = validator.New()
|
||||
res := httptest.NewRecorder()
|
||||
h := NewWithConfig(Config{
|
||||
IAM: iam,
|
||||
Mounts: []string{"/"},
|
||||
})(func(c echo.Context) error {
|
||||
if c.Request().Method == http.MethodPost {
|
||||
if c.Request().URL.Path == "/api/login" {
|
||||
return jwthandler.Login(c)
|
||||
}
|
||||
}
|
||||
|
||||
if c.Request().Method == http.MethodGet {
|
||||
if c.Request().URL.Path == "/api/login/refresh" {
|
||||
return jwthandler.Refresh(c)
|
||||
}
|
||||
}
|
||||
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/login", nil)
|
||||
c := e.NewContext(req, res)
|
||||
|
||||
// No credentials
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
he := err.(api.Error)
|
||||
require.Equal(t, http.StatusForbidden, he.Code)
|
||||
|
||||
// Wrong password
|
||||
login := api.Login{
|
||||
Username: "foobar",
|
||||
Password: "nosecret",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(login)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/login", bytes.NewReader(data))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
c = e.NewContext(req, res)
|
||||
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
he = err.(api.Error)
|
||||
require.Equal(t, http.StatusForbidden, he.Code)
|
||||
|
||||
// Wrong username
|
||||
login = api.Login{
|
||||
Username: "foobaz",
|
||||
Password: "secret",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(login)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/login", bytes.NewReader(data))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
c = e.NewContext(req, res)
|
||||
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
he = err.(api.Error)
|
||||
require.Equal(t, http.StatusForbidden, he.Code)
|
||||
|
||||
// Correct credentials
|
||||
login = api.Login{
|
||||
Username: "foobar",
|
||||
Password: "secret",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(login)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/login", bytes.NewReader(data))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
c = e.NewContext(req, res)
|
||||
res.Body.Reset()
|
||||
|
||||
err = h(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err = io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwt := api.JWT{}
|
||||
err = json.Unmarshal(data, &jwt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// No JWT
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/some/endpoint", nil)
|
||||
c = e.NewContext(req, res)
|
||||
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
he = err.(api.Error)
|
||||
require.Equal(t, http.StatusForbidden, he.Code)
|
||||
|
||||
// With invalid JWT
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid")
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
// With refresh JWT
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwt.RefreshToken)
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
// With access JWT
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwt.AccessToken)
|
||||
err = h(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh JWT
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/login/refresh", nil)
|
||||
c = e.NewContext(req, res)
|
||||
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
he = err.(api.Error)
|
||||
require.Equal(t, http.StatusForbidden, he.Code)
|
||||
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwt.AccessToken)
|
||||
err = h(c)
|
||||
require.Error(t, err)
|
||||
|
||||
he = err.(api.Error)
|
||||
require.Equal(t, http.StatusForbidden, he.Code)
|
||||
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwt.RefreshToken)
|
||||
res.Body.Reset()
|
||||
err = h(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err = io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtrefresh := api.JWTRefresh{}
|
||||
err = json.Unmarshal(data, &jwtrefresh)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/some/endpoint", nil)
|
||||
c = e.NewContext(req, res)
|
||||
|
||||
// With new access JWT
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwtrefresh.AccessToken)
|
||||
err = h(c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@ -98,10 +98,8 @@ func (am *access) ListPolicies(username, domain, resource, actions string) [][]s
|
||||
}
|
||||
|
||||
func (am *access) HasGroup(name string) bool {
|
||||
groups, err := am.enforcer.GetAllDomains()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
groups := am.adapter.getAllGroups()
|
||||
fmt.Printf("groups: %+v\n", groups)
|
||||
|
||||
for _, g := range groups {
|
||||
if g == name {
|
||||
|
||||
@ -480,6 +480,20 @@ func (a *adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (a *adapter) getAllGroups() []string {
|
||||
names := []string{}
|
||||
|
||||
for _, group := range a.groups {
|
||||
if group.Name[0] == '$' {
|
||||
continue
|
||||
}
|
||||
|
||||
names = append(names, group.Name)
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
Name string `json:"name"`
|
||||
Roles map[string][]Role `json:"roles"`
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user