Add cluster IAM, update dependencies

This commit is contained in:
Ingo Oppermann 2023-05-30 21:24:12 +02:00
parent 65e01112a1
commit 89cc42c4b9
No known key found for this signature in database
GPG Key ID: 2AB32426E9DD229E
565 changed files with 56157 additions and 27732 deletions

View File

@ -391,6 +391,68 @@ func (a *api) start() error {
a.sessions = sessions
}
if cfg.Cluster.Enable {
scheme := "http://"
address := cfg.Address
if cfg.TLS.Enable {
scheme = "https://"
address = cfg.TLS.Address
}
host, port, err := gonet.SplitHostPort(address)
if err != nil {
return fmt.Errorf("invalid core address: %s: %w", address, err)
}
if len(host) == 0 {
chost, _, err := gonet.SplitHostPort(cfg.Cluster.Address)
if err != nil {
return fmt.Errorf("invalid cluster address: %s: %w", cfg.Cluster.Address, err)
}
if len(chost) == 0 {
return fmt.Errorf("invalid cluster address: %s: %w", cfg.Cluster.Address, err)
}
host = chost
}
peers := []cluster.Peer{}
for _, p := range cfg.Cluster.Peers {
id, address, found := strings.Cut(p, "@")
if !found {
continue
}
peers = append(peers, cluster.Peer{
ID: id,
Address: address,
})
}
cluster, err := cluster.New(cluster.ClusterConfig{
ID: cfg.ID,
Name: cfg.Name,
Path: filepath.Join(cfg.DB.Dir, "cluster"),
Bootstrap: cfg.Cluster.Bootstrap,
Recover: cfg.Cluster.Recover,
Address: cfg.Cluster.Address,
Peers: peers,
CoreAPIAddress: scheme + gonet.JoinHostPort(host, port),
CoreAPIUsername: cfg.API.Auth.Username,
CoreAPIPassword: cfg.API.Auth.Password,
IPLimiter: a.sessionsLimiter,
Logger: a.log.logger.core.WithComponent("Cluster"),
})
if err != nil {
return fmt.Errorf("unable to create cluster: %w", err)
}
a.cluster = cluster
}
{
superuser := iamidentity.User{
Name: cfg.API.Auth.Username,
@ -421,141 +483,151 @@ func (a *api) start() error {
}
}
fs, err := fs.NewRootedDiskFilesystem(fs.RootedDiskConfig{
Root: filepath.Join(cfg.DB.Dir, "iam"),
})
if err != nil {
return err
}
secret := rand.String(32)
if len(cfg.API.Auth.JWT.Secret) != 0 {
secret = cfg.API.Auth.Username + cfg.API.Auth.Password + cfg.API.Auth.JWT.Secret
}
policyAdapter, err := iamaccess.NewJSONAdapter(fs, "./policy.json", nil)
if err != nil {
return err
}
var manager iam.IAM = nil
identityAdapter, err := iamidentity.NewJSONAdapter(fs, "./users.json", nil)
if err != nil {
return err
}
manager, err := iam.NewIAM(iam.Config{
PolicyAdapter: policyAdapter,
IdentityAdapter: identityAdapter,
Superuser: superuser,
JWTRealm: "datarhei-core",
JWTSecret: secret,
Logger: a.log.logger.core.WithComponent("IAM"),
})
if err != nil {
return fmt.Errorf("iam: %w", err)
}
// Check if there are already file created by IAM. If not, create policies
// and users based on the config in order to mimic the behaviour before IAM.
if len(fs.List("/", "/*.json")) == 0 {
policies := []iamaccess.Policy{
{
Name: "$anon",
Domain: "$none",
Resource: "fs:/**",
Actions: []string{"GET", "HEAD", "OPTIONS"},
},
{
Name: "$anon",
Domain: "$none",
Resource: "api:/api",
Actions: []string{"GET", "HEAD", "OPTIONS"},
},
{
Name: "$anon",
Domain: "$none",
Resource: "api:/api/v3/widget/process/**",
Actions: []string{"GET", "HEAD", "OPTIONS"},
},
if a.cluster != nil {
var err error = nil
manager, err = a.cluster.IAM(superuser, "datarhei-core", secret)
if err != nil {
return err
}
} else {
fs, err := fs.NewRootedDiskFilesystem(fs.RootedDiskConfig{
Root: filepath.Join(cfg.DB.Dir, "iam"),
})
if err != nil {
return err
}
users := map[string]iamidentity.User{}
policyAdapter, err := iamaccess.NewJSONAdapter(fs, "./policy.json", nil)
if err != nil {
return err
}
if cfg.Storage.Memory.Auth.Enable && cfg.Storage.Memory.Auth.Username != superuser.Name {
users[cfg.Storage.Memory.Auth.Username] = iamidentity.User{
Name: cfg.Storage.Memory.Auth.Username,
Auth: iamidentity.UserAuth{
Services: iamidentity.UserAuthServices{
Basic: []string{cfg.Storage.Memory.Auth.Password},
},
identityAdapter, err := iamidentity.NewJSONAdapter(fs, "./users.json", nil)
if err != nil {
return err
}
manager, err = iam.New(iam.Config{
PolicyAdapter: policyAdapter,
IdentityAdapter: identityAdapter,
Superuser: superuser,
JWTRealm: "datarhei-core",
JWTSecret: secret,
Logger: a.log.logger.core.WithComponent("IAM"),
})
if err != nil {
return fmt.Errorf("iam: %w", err)
}
// Check if there are already file created by IAM. If not, create policies
// and users based on the config in order to mimic the behaviour before IAM.
if len(fs.List("/", "/*.json")) == 0 {
policies := []iamaccess.Policy{
{
Name: "$anon",
Domain: "$none",
Resource: "fs:/**",
Actions: []string{"GET", "HEAD", "OPTIONS"},
},
{
Name: "$anon",
Domain: "$none",
Resource: "api:/api",
Actions: []string{"GET", "HEAD", "OPTIONS"},
},
{
Name: "$anon",
Domain: "$none",
Resource: "api:/api/v3/widget/process/**",
Actions: []string{"GET", "HEAD", "OPTIONS"},
},
}
policies = append(policies, iamaccess.Policy{
Name: cfg.Storage.Memory.Auth.Username,
Domain: "$none",
Resource: "fs:/memfs/**",
Actions: []string{"ANY"},
})
}
users := map[string]iamidentity.User{}
for _, s := range cfg.Storage.S3 {
if s.Auth.Enable && s.Auth.Username != superuser.Name {
user, ok := users[s.Auth.Username]
if !ok {
users[s.Auth.Username] = iamidentity.User{
Name: s.Auth.Username,
Auth: iamidentity.UserAuth{
Services: iamidentity.UserAuthServices{
Basic: []string{s.Auth.Password},
},
if cfg.Storage.Memory.Auth.Enable && cfg.Storage.Memory.Auth.Username != superuser.Name {
users[cfg.Storage.Memory.Auth.Username] = iamidentity.User{
Name: cfg.Storage.Memory.Auth.Username,
Auth: iamidentity.UserAuth{
Services: iamidentity.UserAuthServices{
Basic: []string{cfg.Storage.Memory.Auth.Password},
},
}
} else {
user.Auth.Services.Basic = append(user.Auth.Services.Basic, s.Auth.Password)
users[s.Auth.Username] = user
},
}
policies = append(policies, iamaccess.Policy{
Name: s.Auth.Username,
Name: cfg.Storage.Memory.Auth.Username,
Domain: "$none",
Resource: "fs:" + s.Mountpoint + "/**",
Resource: "fs:/memfs/**",
Actions: []string{"ANY"},
})
}
}
if cfg.RTMP.Enable && len(cfg.RTMP.Token) == 0 {
policies = append(policies, iamaccess.Policy{
Name: "$anon",
Domain: "$none",
Resource: "rtmp:/**",
Actions: []string{"ANY"},
})
}
for _, s := range cfg.Storage.S3 {
if s.Auth.Enable && s.Auth.Username != superuser.Name {
user, ok := users[s.Auth.Username]
if !ok {
users[s.Auth.Username] = iamidentity.User{
Name: s.Auth.Username,
Auth: iamidentity.UserAuth{
Services: iamidentity.UserAuthServices{
Basic: []string{s.Auth.Password},
},
},
}
} else {
user.Auth.Services.Basic = append(user.Auth.Services.Basic, s.Auth.Password)
users[s.Auth.Username] = user
}
if cfg.SRT.Enable && len(cfg.SRT.Token) == 0 {
policies = append(policies, iamaccess.Policy{
Name: "$anon",
Domain: "$none",
Resource: "srt:**",
Actions: []string{"ANY"},
})
}
for _, user := range users {
if _, err := manager.GetIdentity(user.Name); err == nil {
continue
policies = append(policies, iamaccess.Policy{
Name: s.Auth.Username,
Domain: "$none",
Resource: "fs:" + s.Mountpoint + "/**",
Actions: []string{"ANY"},
})
}
}
err := manager.CreateIdentity(user)
if err != nil {
return fmt.Errorf("iam: %w", err)
if cfg.RTMP.Enable && len(cfg.RTMP.Token) == 0 {
policies = append(policies, iamaccess.Policy{
Name: "$anon",
Domain: "$none",
Resource: "rtmp:/**",
Actions: []string{"ANY"},
})
}
}
for _, policy := range policies {
manager.AddPolicy(policy.Name, policy.Domain, policy.Resource, policy.Actions)
if cfg.SRT.Enable && len(cfg.SRT.Token) == 0 {
policies = append(policies, iamaccess.Policy{
Name: "$anon",
Domain: "$none",
Resource: "srt:**",
Actions: []string{"ANY"},
})
}
for _, user := range users {
if _, err := manager.GetIdentity(user.Name); err == nil {
continue
}
err := manager.CreateIdentity(user)
if err != nil {
return fmt.Errorf("iam: %w", err)
}
}
for _, policy := range policies {
manager.AddPolicy(policy.Name, policy.Domain, policy.Resource, policy.Actions)
}
}
}
@ -845,68 +917,6 @@ func (a *api) start() error {
a.restream = restream
if cfg.Cluster.Enable {
scheme := "http://"
address := cfg.Address
if cfg.TLS.Enable {
scheme = "https://"
address = cfg.TLS.Address
}
host, port, err := gonet.SplitHostPort(address)
if err != nil {
return fmt.Errorf("invalid core address: %s: %w", address, err)
}
if len(host) == 0 {
chost, _, err := gonet.SplitHostPort(cfg.Cluster.Address)
if err != nil {
return fmt.Errorf("invalid cluster address: %s: %w", cfg.Cluster.Address, err)
}
if len(chost) == 0 {
return fmt.Errorf("invalid cluster address: %s: %w", cfg.Cluster.Address, err)
}
host = chost
}
peers := []cluster.Peer{}
for _, p := range cfg.Cluster.Peers {
id, address, found := strings.Cut(p, "@")
if !found {
continue
}
peers = append(peers, cluster.Peer{
ID: id,
Address: address,
})
}
cluster, err := cluster.New(cluster.ClusterConfig{
ID: cfg.ID,
Name: cfg.Name,
Path: filepath.Join(cfg.DB.Dir, "cluster"),
Bootstrap: cfg.Cluster.Bootstrap,
Recover: cfg.Cluster.Recover,
Address: cfg.Cluster.Address,
Peers: peers,
CoreAPIAddress: scheme + gonet.JoinHostPort(host, port),
CoreAPIUsername: cfg.API.Auth.Username,
CoreAPIPassword: cfg.API.Auth.Password,
IPLimiter: a.sessionsLimiter,
Logger: a.log.logger.core.WithComponent("Cluster"),
})
if err != nil {
return fmt.Errorf("unable to create cluster: %w", err)
}
a.cluster = cluster
}
metrics, err := monitor.NewHistory(monitor.HistoryConfig{
Enable: cfg.Metrics.Enable,
Timerange: time.Duration(cfg.Metrics.Range) * time.Second,

View File

@ -18,6 +18,8 @@ import (
"github.com/datarhei/core/v16/ffmpeg"
"github.com/datarhei/core/v16/ffmpeg/skills"
"github.com/datarhei/core/v16/iam"
iamaccess "github.com/datarhei/core/v16/iam/access"
iamidentity "github.com/datarhei/core/v16/iam/identity"
"github.com/datarhei/core/v16/io/fs"
"github.com/datarhei/core/v16/restream"
"github.com/datarhei/core/v16/restream/app"
@ -1453,9 +1455,20 @@ func probeInput(binary string, config app.Config) app.Probe {
return app.Probe{}
}
iam, _ := iam.NewIAM(iam.Config{
FS: dummyfs,
Superuser: iam.User{
policyAdapter, err := iamaccess.NewJSONAdapter(dummyfs, "./policy.json", nil)
if err != nil {
return app.Probe{}
}
identityAdapter, err := iamidentity.NewJSONAdapter(dummyfs, "./users.json", nil)
if err != nil {
return app.Probe{}
}
iam, _ := iam.New(iam.Config{
PolicyAdapter: policyAdapter,
IdentityAdapter: identityAdapter,
Superuser: iamidentity.User{
Name: "foobar",
},
JWTRealm: "",

View File

@ -101,7 +101,7 @@ func NewAPI(config APIConfig) (API, error) {
return c.JSON(http.StatusOK, "OK")
})
a.router.POST("/v1/server/:id", func(c echo.Context) error {
a.router.DELETE("/v1/server/:id", func(c echo.Context) error {
id := util.PathParam(c, "id")
a.logger.Debug().WithFields(log.Fields{
@ -236,6 +236,30 @@ func NewAPI(config APIConfig) (API, error) {
return c.JSON(http.StatusOK, "OK")
})
// PUT /v1/iam/user/:name/policies replaces the policies of a user.
// The request is rejected with 508 if it originated from this node
// (forwarding loop protection via the X-Cluster-Origin header).
a.router.PUT("/v1/iam/user/:name/policies", func(c echo.Context) error {
	r := client.SetPoliciesRequest{}
	if err := util.ShouldBindJSON(c, &r); err != nil {
		return httpapi.Err(http.StatusBadRequest, "Invalid JSON", "%s", err)
	}

	origin := c.Request().Header.Get("X-Cluster-Origin")

	if origin == a.id {
		return httpapi.Err(http.StatusLoopDetected, "", "breaking circuit")
	}

	a.logger.Debug().WithField("policies", r.Policies).Log("Set policies request")

	// Declare err with := — the err inside the ShouldBindJSON guard above
	// is scoped to that if statement and is not visible here.
	err := a.cluster.SetPolicies(origin, r.Name, r.Policies)
	if err != nil {
		a.logger.Debug().WithError(err).WithField("policies", r.Policies).Log("Unable to set policies")
		return httpapi.Err(http.StatusInternalServerError, "unable to set policies", "%s", err)
	}

	return c.JSON(http.StatusOK, "OK")
})
a.router.DELETE("/v1/iam/user/:name", func(c echo.Context) error {
name := util.PathParam(c, "name")

View File

@ -9,6 +9,7 @@ import (
"time"
httpapi "github.com/datarhei/core/v16/http/api"
iamaccess "github.com/datarhei/core/v16/iam/access"
iamidentity "github.com/datarhei/core/v16/iam/identity"
"github.com/datarhei/core/v16/restream/app"
)
@ -35,6 +36,11 @@ type AddIdentityRequest struct {
Identity iamidentity.User `json:"identity"`
}
// SetPoliciesRequest is the request body for replacing all policies
// of the named user via the cluster-internal API.
type SetPoliciesRequest struct {
	Name     string             `json:"name"`
	Policies []iamaccess.Policy `json:"policies"`
}
type APIClient struct {
Address string
Client *http.Client
@ -111,6 +117,17 @@ func (c *APIClient) AddIdentity(origin string, r AddIdentityRequest) error {
return err
}
// SetPolicies sends the given set of policies for a user to the cluster
// API of another node. The origin is forwarded so that the receiver can
// detect forwarding loops.
// NOTE(review): name is inserted into the URL path unescaped — presumably
// identity names never contain reserved URL characters; confirm, otherwise
// escape with url.PathEscape.
func (c *APIClient) SetPolicies(origin, name string, r SetPoliciesRequest) error {
	data, err := json.Marshal(r)
	if err != nil {
		return err
	}

	_, err = c.call(http.MethodPut, "/iam/user/"+name+"/policies", "application/json", bytes.NewReader(data), origin)

	return err
}
func (c *APIClient) RemoveIdentity(origin string, name string) error {
_, err := c.call(http.MethodDelete, "/iam/user/"+name, "application/json", nil, origin)

View File

@ -13,9 +13,12 @@ import (
apiclient "github.com/datarhei/core/v16/cluster/client"
"github.com/datarhei/core/v16/cluster/forwarder"
clusteriam "github.com/datarhei/core/v16/cluster/iam"
"github.com/datarhei/core/v16/cluster/proxy"
"github.com/datarhei/core/v16/cluster/raft"
"github.com/datarhei/core/v16/cluster/store"
"github.com/datarhei/core/v16/iam"
iamaccess "github.com/datarhei/core/v16/iam/access"
iamidentity "github.com/datarhei/core/v16/iam/identity"
"github.com/datarhei/core/v16/log"
"github.com/datarhei/core/v16/net"
@ -68,7 +71,10 @@ type Cluster interface {
RemoveProcess(origin, id string) error
UpdateProcess(origin, id string, config *app.Config) error
IAM(superuser iamidentity.User, jwtRealm, jwtSecret string) (iam.IAM, error)
ListIdentities() store.Users
AddIdentity(origin string, identity iamidentity.User) error
SetPolicies(origin, name string, policies []iamaccess.Policy) error
RemoveIdentity(origin string, name string) error
ProxyReader() proxy.ProxyReader
@ -746,6 +752,36 @@ func (c *cluster) UpdateProcess(origin, id string, config *app.Config) error {
return c.applyCommand(cmd)
}
// IAM constructs a cluster-backed IAM manager. Policies and identities
// are read from the raft store through the cluster IAM adapters, while
// enforcement and JWT handling are delegated to the regular IAM
// implementation wrapped by clusteriam.New.
func (c *cluster) IAM(superuser iamidentity.User, jwtRealm, jwtSecret string) (iam.IAM, error) {
	policyAdapter, err := clusteriam.NewPolicyAdapter(c.store)
	if err != nil {
		return nil, fmt.Errorf("cluster policy adapter: %w", err)
	}

	identityAdapter, err := clusteriam.NewIdentityAdapter(c.store)
	if err != nil {
		// Fixed typo in error message: "identitry" -> "identity".
		return nil, fmt.Errorf("cluster identity adapter: %w", err)
	}

	iam, err := clusteriam.New(iam.Config{
		PolicyAdapter:   policyAdapter,
		IdentityAdapter: identityAdapter,
		Superuser:       superuser,
		JWTRealm:        jwtRealm,
		JWTSecret:       jwtSecret,
		Logger:          c.logger.WithField("logname", "iam"),
	}, c.store)
	if err != nil {
		return nil, fmt.Errorf("cluster iam: %w", err)
	}

	return iam, nil
}
// ListIdentities returns all identities currently known to the
// cluster store.
func (c *cluster) ListIdentities() store.Users {
	return c.store.UserList()
}
func (c *cluster) AddIdentity(origin string, identity iamidentity.User) error {
if !c.IsRaftLeader() {
return c.forwarder.AddIdentity(origin, identity)
@ -761,13 +797,29 @@ func (c *cluster) AddIdentity(origin string, identity iamidentity.User) error {
return c.applyCommand(cmd)
}
// SetPolicies replaces the policies of the named user. If this node is
// not the raft leader, the request is forwarded to the leader;
// otherwise the change is applied to the store as a raft command.
func (c *cluster) SetPolicies(origin, name string, policies []iamaccess.Policy) error {
	if !c.IsRaftLeader() {
		return c.forwarder.SetPolicies(origin, name, policies)
	}

	cmd := &store.Command{
		Operation: store.OpSetPolicies,
		Data: &store.CommandSetPolicies{
			Name:     name,
			Policies: policies,
		},
	}

	return c.applyCommand(cmd)
}
func (c *cluster) RemoveIdentity(origin string, name string) error {
if !c.IsRaftLeader() {
return c.forwarder.RemoveIdentity(origin, name)
}
cmd := &store.Command{
Operation: store.OpAddIdentity,
Operation: store.OpRemoveIdentity,
Data: &store.CommandRemoveIdentity{
Name: name,
},

View File

@ -7,6 +7,7 @@ import (
"time"
apiclient "github.com/datarhei/core/v16/cluster/client"
iamaccess "github.com/datarhei/core/v16/iam/access"
iamidentity "github.com/datarhei/core/v16/iam/identity"
"github.com/datarhei/core/v16/log"
"github.com/datarhei/core/v16/restream/app"
@ -26,6 +27,7 @@ type Forwarder interface {
RemoveProcess(origin, id string) error
AddIdentity(origin string, identity iamidentity.User) error
SetPolicies(origin, name string, policies []iamaccess.Policy) error
RemoveIdentity(origin string, name string) error
}
@ -196,6 +198,23 @@ func (f *forwarder) AddIdentity(origin string, identity iamidentity.User) error
return client.AddIdentity(origin, r)
}
// SetPolicies forwards a request to the leader to replace the policies
// of the named user. An empty origin is substituted with the ID of
// this node.
func (f *forwarder) SetPolicies(origin, name string, policies []iamaccess.Policy) error {
	if len(origin) == 0 {
		origin = f.id
	}

	req := apiclient.SetPoliciesRequest{
		Name:     name,
		Policies: policies,
	}

	f.lock.RLock()
	c := f.client
	f.lock.RUnlock()

	return c.SetPolicies(origin, name, req)
}
func (f *forwarder) RemoveIdentity(origin string, name string) error {
if origin == "" {
origin = f.id

99
cluster/iam/adapter.go Normal file
View File

@ -0,0 +1,99 @@
package iam
import (
"strings"
"github.com/datarhei/core/v16/cluster/store"
iamaccess "github.com/datarhei/core/v16/iam/access"
iamidentity "github.com/datarhei/core/v16/iam/identity"
"github.com/casbin/casbin/v2/model"
)
// policyAdapter exposes the policies held in the cluster store as a
// casbin policy adapter. It is read-only: all mutating adapter methods
// are no-ops, because policy changes go through raft commands.
type policyAdapter struct {
	store store.Store
}
// NewPolicyAdapter returns a casbin policy adapter that reads its
// rules from the given cluster store.
func NewPolicyAdapter(store store.Store) (iamaccess.Adapter, error) {
	return &policyAdapter{store: store}, nil
}
// LoadPolicy loads all policy rules from the cluster store into the
// casbin model, replacing whatever the model currently holds. Each
// stored policy becomes one "p" rule of the form
// [name, domain, resource, action1|action2|...].
func (a *policyAdapter) LoadPolicy(model model.Model) error {
	policies := a.store.PolicyList()

	// Pre-size: exactly one casbin rule per stored policy.
	rules := make([][]string, 0, len(policies.Policies))

	for _, p := range policies.Policies {
		rules = append(rules, []string{
			p.Name,
			p.Domain,
			p.Resource,
			strings.Join(p.Actions, "|"),
		})
	}

	model.ClearPolicy()
	model.AddPolicies("p", "p", rules)

	return nil
}
// SavePolicy is a no-op: policies are persisted through raft commands
// in the cluster store, never written back by casbin.
func (a *policyAdapter) SavePolicy(model model.Model) error {
	return nil
}

// AddPolicy is a no-op; policy mutations go through the cluster store.
func (a *policyAdapter) AddPolicy(sec, ptype string, rule []string) error {
	return nil
}

// AddPolicies is a no-op; policy mutations go through the cluster store.
func (a *policyAdapter) AddPolicies(sec string, ptype string, rules [][]string) error {
	return nil
}

// RemovePolicy is a no-op; policy mutations go through the cluster store.
func (a *policyAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
	return nil
}

// RemovePolicies is a no-op; policy mutations go through the cluster store.
func (a *policyAdapter) RemovePolicies(sec string, ptype string, rules [][]string) error {
	return nil
}

// RemoveFilteredPolicy is a no-op; policy mutations go through the
// cluster store.
func (a *policyAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
	return nil
}

// AllDomains always returns nil.
// NOTE(review): the JSON adapter presumably returns actual domains
// here — confirm whether callers of the cluster adapter rely on this.
func (a *policyAdapter) AllDomains() []string {
	return nil
}

// HasDomain always reports false for this adapter.
func (a *policyAdapter) HasDomain(name string) bool {
	return false
}
// identityAdapter exposes the users held in the cluster store as an
// IAM identity adapter. It is read-only: identity changes go through
// raft commands.
type identityAdapter struct {
	store store.Store
}

// NewIdentityAdapter returns an identity adapter backed by the given
// cluster store.
func NewIdentityAdapter(store store.Store) (iamidentity.Adapter, error) {
	a := &identityAdapter{
		store: store,
	}

	return a, nil
}

// LoadIdentities returns all users currently in the cluster store.
func (a *identityAdapter) LoadIdentities() ([]iamidentity.User, error) {
	users := a.store.UserList()

	return users.Users, nil
}

// SaveIdentities is a no-op: identities are persisted through raft
// commands in the cluster store.
func (a *identityAdapter) SaveIdentities([]iamidentity.User) error {
	return nil
}

127
cluster/iam/iam.go Normal file
View File

@ -0,0 +1,127 @@
package iam
import (
"github.com/datarhei/core/v16/cluster/store"
"github.com/datarhei/core/v16/iam"
"github.com/datarhei/core/v16/iam/access"
"github.com/datarhei/core/v16/iam/identity"
"github.com/datarhei/core/v16/log"
)
// manager wraps a regular IAM instance and keeps it in sync with the
// cluster store: whenever an IAM-related raft command is applied, the
// affected data (identities or policies) is reloaded.
type manager struct {
	iam    iam.IAM     // wrapped IAM doing the actual enforcement
	store  store.Store // raft-backed cluster store
	logger log.Logger
}
// New creates a cluster-aware IAM. It wraps the regular IAM created
// from config and registers an apply callback on the store so that
// identities and policies are reloaded after relevant raft commands.
func New(config iam.Config, store store.Store) (iam.IAM, error) {
	inner, err := iam.New(config)
	if err != nil {
		return nil, err
	}

	logger := config.Logger
	if logger == nil {
		logger = log.New("")
	}

	m := &manager{
		iam:    inner,
		store:  store,
		logger: logger,
	}

	store.OnApply(m.apply)

	return m, nil
}
// apply reacts to a raft operation that has been applied to the store
// by reloading the IAM data the operation affects.
func (m *manager) apply(op store.Operation) {
	m.logger.Debug().WithField("operation", string(op)).Log("")

	switch op {
	case store.OpAddIdentity, store.OpRemoveIdentity:
		m.ReloadIndentities()
	case store.OpSetPolicies:
		m.ReloadPolicies()
	}
}
// Enforce delegates the access decision to the wrapped IAM.
func (m *manager) Enforce(name, domain, resource, action string) bool {
	return m.iam.Enforce(name, domain, resource, action)
}

// HasDomain delegates to the wrapped IAM.
func (m *manager) HasDomain(domain string) bool {
	return m.iam.HasDomain(domain)
}

// ListDomains delegates to the wrapped IAM.
func (m *manager) ListDomains() []string {
	return m.iam.ListDomains()
}

// HasPolicy delegates to the wrapped IAM.
func (m *manager) HasPolicy(name, domain, resource string, actions []string) bool {
	return m.iam.HasPolicy(name, domain, resource, actions)
}

// AddPolicy is a no-op that reports success: in cluster mode policy
// mutations go through raft commands, and the store's apply callback
// then triggers ReloadPolicies.
func (m *manager) AddPolicy(name, domain, resource string, actions []string) bool {
	return true
}

// RemovePolicy is a no-op that reports success; see AddPolicy.
func (m *manager) RemovePolicy(name, domain, resource string, actions []string) bool {
	return true
}

// ListPolicies delegates to the wrapped IAM.
func (m *manager) ListPolicies(name, domain, resource string, actions []string) []access.Policy {
	return m.iam.ListPolicies(name, domain, resource, actions)
}

// ReloadPolicies reloads all policies from the store via the wrapped IAM.
func (m *manager) ReloadPolicies() error {
	m.logger.Info().Log("Reloading policies")

	return m.iam.ReloadPolicies()
}

// Validators delegates to the wrapped IAM.
func (m *manager) Validators() []string {
	return m.iam.Validators()
}
// CreateIdentity is a no-op that reports success: in cluster mode
// identity mutations go through raft commands, and the store's apply
// callback then triggers ReloadIndentities.
// NOTE(review): the error is silently nil even though nothing was
// created here — confirm callers always use the cluster API instead.
func (m *manager) CreateIdentity(u identity.User) error {
	return nil
}

// GetIdentity delegates to the wrapped IAM.
func (m *manager) GetIdentity(name string) (identity.User, error) {
	return m.iam.GetIdentity(name)
}

// UpdateIdentity is a no-op that reports success; see CreateIdentity.
func (m *manager) UpdateIdentity(name string, u identity.User) error {
	return nil
}

// DeleteIdentity is a no-op that reports success; see CreateIdentity.
func (m *manager) DeleteIdentity(name string) error {
	return nil
}

// ListIdentities delegates to the wrapped IAM.
func (m *manager) ListIdentities() []identity.User {
	return m.iam.ListIdentities()
}

// ReloadIndentities reloads all identities from the store via the
// wrapped IAM. (The misspelling mirrors the iam.IAM interface method.)
func (m *manager) ReloadIndentities() error {
	m.logger.Info().Log("Reloading identities")

	return m.iam.ReloadIndentities()
}

// GetVerifier delegates to the wrapped IAM.
func (m *manager) GetVerifier(name string) (identity.Verifier, error) {
	return m.iam.GetVerifier(name)
}

// GetVerifierFromAuth0 delegates to the wrapped IAM.
func (m *manager) GetVerifierFromAuth0(name string) (identity.Verifier, error) {
	return m.iam.GetVerifierFromAuth0(name)
}

// GetDefaultVerifier delegates to the wrapped IAM.
func (m *manager) GetDefaultVerifier() identity.Verifier {
	return m.iam.GetDefaultVerifier()
}

// CreateJWT delegates to the wrapped IAM.
func (m *manager) CreateJWT(name string) (string, string, error) {
	return m.iam.CreateJWT(name)
}

// Close is a no-op; the wrapped IAM is not closed here.
func (m *manager) Close() {}

View File

@ -170,7 +170,7 @@ func (c *cluster) monitorLeadership() {
defer emergencyLeaderLoop.Done()
c.leaderLoop(ch, true)
}(weAreEmergencyLeaderCh)
c.logger.Info().Log("Sluster emergency leadership acquired")
c.logger.Info().Log("Cluster emergency leadership acquired")
c.leaderLock.Lock()
c.isRaftLeader = false

View File

@ -18,10 +18,13 @@ import (
type Store interface {
raft.FSM
OnApply(func(op Operation))
ProcessList() []Process
GetProcess(id string) (Process, error)
UserList() Users
PolicyList() Policies
}
type Process struct {
@ -40,18 +43,19 @@ type Policies struct {
Policies []access.Policy
}
type operation string
type Operation string
const (
OpAddProcess operation = "addProcess"
OpRemoveProcess operation = "removeProcess"
OpUpdateProcess operation = "updateProcess"
OpAddIdentity operation = "addIdentity"
OpRemoveIdentity operation = "removeIdentity"
OpAddProcess Operation = "addProcess"
OpRemoveProcess Operation = "removeProcess"
OpUpdateProcess Operation = "updateProcess"
OpAddIdentity Operation = "addIdentity"
OpRemoveIdentity Operation = "removeIdentity"
OpSetPolicies Operation = "setPolicies"
)
type Command struct {
Operation operation
Operation Operation
Data interface{}
}
@ -76,6 +80,11 @@ type CommandRemoveIdentity struct {
Name string
}
type CommandSetPolicies struct {
Name string
Policies []access.Policy
}
// Implement a FSM
type store struct {
lock sync.RWMutex
@ -86,6 +95,13 @@ type store struct {
Users map[string]identity.User
}
Policies struct {
UpdatedAt time.Time
Policies map[string][]access.Policy
}
callback func(op Operation)
logger log.Logger
}
@ -100,6 +116,7 @@ func NewStore(config Config) (Store, error) {
}
s.Users.Users = map[string]identity.User{}
s.Policies.Policies = map[string][]access.Policy{}
if s.logger == nil {
s.logger = log.New("")
@ -198,22 +215,47 @@ func (s *store) Apply(entry *raft.Log) interface{} {
s.lock.Lock()
delete(s.Users.Users, cmd.Name)
s.Users.UpdatedAt = time.Now()
delete(s.Policies.Policies, cmd.Name)
s.Policies.UpdatedAt = time.Now()
s.lock.Unlock()
case OpSetPolicies:
b, _ := json.Marshal(c.Data)
cmd := CommandSetPolicies{}
json.Unmarshal(b, &cmd)
s.lock.Lock()
delete(s.Policies.Policies, cmd.Name)
s.Policies.Policies[cmd.Name] = cmd.Policies
s.Policies.UpdatedAt = time.Now()
s.lock.Unlock()
default:
s.logger.Warn().WithField("operation", c.Operation).Log("Unknown operation")
}
s.lock.RLock()
if s.callback != nil {
s.callback(c.Operation)
}
s.lock.RUnlock()
s.lock.RLock()
s.logger.Debug().WithField("processes", s.Process).Log("")
s.lock.RUnlock()
return nil
}
// OnApply registers a callback that is invoked after each raft log
// entry has been applied to the store. The callback receives the
// operation type of the applied command.
func (s *store) OnApply(fn func(op Operation)) {
	s.lock.Lock()
	defer s.lock.Unlock()

	s.callback = fn
}
func (s *store) Snapshot() (raft.FSMSnapshot, error) {
s.logger.Debug().Log("Snapshot request")
s.lock.Lock()
defer s.lock.Unlock()
s.lock.RLock()
defer s.lock.RUnlock()
data, err := json.Marshal(s)
if err != nil {
@ -287,6 +329,21 @@ func (s *store) UserList() Users {
return u
}
// PolicyList returns a snapshot of all policies in the store,
// flattened across all user names.
func (s *store) PolicyList() Policies {
	s.lock.RLock()
	defer s.lock.RUnlock()

	list := Policies{
		UpdatedAt: s.Policies.UpdatedAt,
	}

	for _, perUser := range s.Policies.Policies {
		list.Policies = append(list.Policies, perUser...)
	}

	return list
}
type fsmSnapshot struct {
data []byte
}

View File

@ -1,4 +1,5 @@
// Code generated by swaggo/swag. DO NOT EDIT
// Code generated by swaggo/swag. DO NOT EDIT.
package docs
import "github.com/swaggo/swag"
@ -159,6 +160,170 @@ const docTemplate = `{
}
}
},
"/api/v3/cluster/iam/user": {
"post": {
"security": [
{
"ApiKeyAuth": []
}
],
"description": "Add a new identity",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"v16.?.?"
],
"summary": "Add a new identity",
"operationId": "cluster-3-add-identity",
"parameters": [
{
"description": "Identity",
"name": "config",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/api.IAMUser"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/api.IAMUser"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/api.Error"
}
}
}
}
},
"/api/v3/cluster/iam/user/{name}": {
"delete": {
"security": [
{
"ApiKeyAuth": []
}
],
"description": "Delete an identity by its name",
"produces": [
"application/json"
],
"tags": [
"v16.?.?"
],
"summary": "Delete an identity by its name",
"operationId": "cluster-3-delete-identity",
"parameters": [
{
"type": "string",
"description": "Identity name",
"name": "name",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "string"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/api.Error"
}
}
}
}
},
"/api/v3/cluster/iam/user/{name}/policy": {
"put": {
"security": [
{
"ApiKeyAuth": []
}
],
"description": "Replace policies of an user",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"v16.?.?"
],
"summary": "Replace policies of an user",
"operationId": "cluster-3-update-user-policies",
"parameters": [
{
"type": "string",
"description": "Username",
"name": "name",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Domain of the acting user",
"name": "domain",
"in": "query"
},
{
"description": "Policy definitions",
"name": "user",
"in": "body",
"required": true,
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/api.IAMPolicy"
}
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/api.IAMPolicy"
}
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/api.Error"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/api.Error"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/api.Error"
}
}
}
}
},
"/api/v3/cluster/node": {
"get": {
"security": [
@ -1075,7 +1240,7 @@ const docTemplate = `{
"v16.?.?"
],
"summary": "Replace policies of an user",
"operationId": "iam-3-update-user",
"operationId": "iam-3-update-user-policies",
"parameters": [
{
"type": "string",
@ -1370,15 +1535,27 @@ const docTemplate = `{
},
{
"type": "string",
"description": "Glob pattern for process IDs. If empty all IDs will be returned. Intersected with results from refpattern.",
"description": "Glob pattern for process IDs. If empty all IDs will be returned. Intersected with results from other pattern matches.",
"name": "idpattern",
"in": "query"
},
{
"type": "string",
"description": "Glob pattern for process references. If empty all IDs will be returned. Intersected with results from idpattern.",
"description": "Glob pattern for process references. If empty all IDs will be returned. Intersected with results from other pattern matches.",
"name": "refpattern",
"in": "query"
},
{
"type": "string",
"description": "Glob pattern for process owners. If empty all IDs will be returned. Intersected with results from other pattern matches.",
"name": "ownerpattern",
"in": "query"
},
{
"type": "string",
"description": "Glob pattern for process domain. If empty all IDs will be returned. Intersected with results from other pattern matches.",
"name": "domainpattern",
"in": "query"
}
],
"responses": {
@ -3177,6 +3354,7 @@ const docTemplate = `{
"type": "boolean"
},
"token": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -3255,6 +3433,7 @@ const docTemplate = `{
"type": "string"
},
"token": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -3326,15 +3505,19 @@ const docTemplate = `{
"type": "object",
"properties": {
"auth": {
"description": "Deprecated, use IAM",
"type": "object",
"properties": {
"enable": {
"description": "Deprecated, use IAM",
"type": "boolean"
},
"password": {
"description": "Deprecated, use IAM",
"type": "string"
},
"username": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -3536,16 +3719,13 @@ const docTemplate = `{
"$ref": "#/definitions/api.IAMUserAuthAPIAuth0"
},
"userpass": {
"$ref": "#/definitions/api.IAMUserAuthPassword"
"type": "string"
}
}
},
"api.IAMUserAuthAPIAuth0": {
"type": "object",
"properties": {
"enable": {
"type": "boolean"
},
"tenant": {
"$ref": "#/definitions/api.IAMAuth0Tenant"
},
@ -3554,22 +3734,14 @@ const docTemplate = `{
}
}
},
"api.IAMUserAuthPassword": {
"type": "object",
"properties": {
"enable": {
"type": "boolean"
},
"password": {
"type": "string"
}
}
},
"api.IAMUserAuthServices": {
"type": "object",
"properties": {
"basic": {
"$ref": "#/definitions/api.IAMUserAuthPassword"
"type": "array",
"items": {
"type": "string"
}
},
"token": {
"type": "array",
@ -3902,7 +4074,7 @@ const docTemplate = `{
"autostart": {
"type": "boolean"
},
"group": {
"domain": {
"type": "string"
},
"id": {
@ -3929,6 +4101,9 @@ const docTemplate = `{
"$ref": "#/definitions/api.ProcessConfigIO"
}
},
"owner": {
"type": "string"
},
"reconnect": {
"type": "boolean"
},
@ -5093,6 +5268,7 @@ const docTemplate = `{
"type": "boolean"
},
"token": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -5171,6 +5347,7 @@ const docTemplate = `{
"type": "string"
},
"token": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -5242,15 +5419,19 @@ const docTemplate = `{
"type": "object",
"properties": {
"auth": {
"description": "Deprecated, use IAM",
"type": "object",
"properties": {
"enable": {
"description": "Deprecated, use IAM",
"type": "boolean"
},
"password": {
"description": "Deprecated, use IAM",
"type": "string"
},
"username": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -5622,7 +5803,12 @@ const docTemplate = `{
"type": "string"
},
"auth": {
"$ref": "#/definitions/value.S3StorageAuth"
"description": "Deprecated, use IAM",
"allOf": [
{
"$ref": "#/definitions/value.S3StorageAuth"
}
]
},
"bucket": {
"type": "string"
@ -5651,12 +5837,15 @@ const docTemplate = `{
"type": "object",
"properties": {
"enable": {
"description": "Deprecated, use IAM",
"type": "boolean"
},
"password": {
"description": "Deprecated, use IAM",
"type": "string"
},
"username": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -5694,6 +5883,8 @@ var SwaggerInfo = &swag.Spec{
Description: "Expose REST API for the datarhei Core",
InfoInstanceName: "swagger",
SwaggerTemplate: docTemplate,
LeftDelim: "{{",
RightDelim: "}}",
}
func init() {

View File

@ -152,6 +152,170 @@
}
}
},
"/api/v3/cluster/iam/user": {
"post": {
"security": [
{
"ApiKeyAuth": []
}
],
"description": "Add a new identity",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"v16.?.?"
],
                "summary": "Add a new identity",
"operationId": "cluster-3-add-identity",
"parameters": [
{
"description": "Identity",
"name": "config",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/api.IAMUser"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/api.IAMUser"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/api.Error"
}
}
}
}
},
"/api/v3/cluster/iam/user/{name}": {
"delete": {
"security": [
{
"ApiKeyAuth": []
}
],
"description": "Delete an identity by its name",
"produces": [
"application/json"
],
"tags": [
"v16.?.?"
],
"summary": "Delete an identity by its name",
"operationId": "cluster-3-delete-identity",
"parameters": [
{
"type": "string",
"description": "Identity name",
"name": "name",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "string"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/api.Error"
}
}
}
}
},
"/api/v3/cluster/iam/user/{name}/policy": {
"put": {
"security": [
{
"ApiKeyAuth": []
}
],
"description": "Replace policies of an user",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"v16.?.?"
],
"summary": "Replace policies of an user",
"operationId": "cluster-3-update-user-policies",
"parameters": [
{
"type": "string",
"description": "Username",
"name": "name",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Domain of the acting user",
"name": "domain",
"in": "query"
},
{
"description": "Policy definitions",
"name": "user",
"in": "body",
"required": true,
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/api.IAMPolicy"
}
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/api.IAMPolicy"
}
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/api.Error"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/api.Error"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/api.Error"
}
}
}
}
},
"/api/v3/cluster/node": {
"get": {
"security": [
@ -1068,7 +1232,7 @@
"v16.?.?"
],
"summary": "Replace policies of an user",
"operationId": "iam-3-update-user",
"operationId": "iam-3-update-user-policies",
"parameters": [
{
"type": "string",
@ -1363,15 +1527,27 @@
},
{
"type": "string",
"description": "Glob pattern for process IDs. If empty all IDs will be returned. Intersected with results from refpattern.",
"description": "Glob pattern for process IDs. If empty all IDs will be returned. Intersected with results from other pattern matches.",
"name": "idpattern",
"in": "query"
},
{
"type": "string",
"description": "Glob pattern for process references. If empty all IDs will be returned. Intersected with results from idpattern.",
"description": "Glob pattern for process references. If empty all IDs will be returned. Intersected with results from other pattern matches.",
"name": "refpattern",
"in": "query"
},
{
"type": "string",
"description": "Glob pattern for process owners. If empty all IDs will be returned. Intersected with results from other pattern matches.",
"name": "ownerpattern",
"in": "query"
},
{
"type": "string",
"description": "Glob pattern for process domain. If empty all IDs will be returned. Intersected with results from other pattern matches.",
"name": "domainpattern",
"in": "query"
}
],
"responses": {
@ -3170,6 +3346,7 @@
"type": "boolean"
},
"token": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -3248,6 +3425,7 @@
"type": "string"
},
"token": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -3319,15 +3497,19 @@
"type": "object",
"properties": {
"auth": {
"description": "Deprecated, use IAM",
"type": "object",
"properties": {
"enable": {
"description": "Deprecated, use IAM",
"type": "boolean"
},
"password": {
"description": "Deprecated, use IAM",
"type": "string"
},
"username": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -3529,16 +3711,13 @@
"$ref": "#/definitions/api.IAMUserAuthAPIAuth0"
},
"userpass": {
"$ref": "#/definitions/api.IAMUserAuthPassword"
"type": "string"
}
}
},
"api.IAMUserAuthAPIAuth0": {
"type": "object",
"properties": {
"enable": {
"type": "boolean"
},
"tenant": {
"$ref": "#/definitions/api.IAMAuth0Tenant"
},
@ -3547,22 +3726,14 @@
}
}
},
"api.IAMUserAuthPassword": {
"type": "object",
"properties": {
"enable": {
"type": "boolean"
},
"password": {
"type": "string"
}
}
},
"api.IAMUserAuthServices": {
"type": "object",
"properties": {
"basic": {
"$ref": "#/definitions/api.IAMUserAuthPassword"
"type": "array",
"items": {
"type": "string"
}
},
"token": {
"type": "array",
@ -3895,7 +4066,7 @@
"autostart": {
"type": "boolean"
},
"group": {
"domain": {
"type": "string"
},
"id": {
@ -3922,6 +4093,9 @@
"$ref": "#/definitions/api.ProcessConfigIO"
}
},
"owner": {
"type": "string"
},
"reconnect": {
"type": "boolean"
},
@ -5086,6 +5260,7 @@
"type": "boolean"
},
"token": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -5164,6 +5339,7 @@
"type": "string"
},
"token": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -5235,15 +5411,19 @@
"type": "object",
"properties": {
"auth": {
"description": "Deprecated, use IAM",
"type": "object",
"properties": {
"enable": {
"description": "Deprecated, use IAM",
"type": "boolean"
},
"password": {
"description": "Deprecated, use IAM",
"type": "string"
},
"username": {
"description": "Deprecated, use IAM",
"type": "string"
}
}
@ -5615,7 +5795,12 @@
"type": "string"
},
"auth": {
"$ref": "#/definitions/value.S3StorageAuth"
"description": "Deprecated, use IAM",
"allOf": [
{
"$ref": "#/definitions/value.S3StorageAuth"
}
]
},
"bucket": {
"type": "string"
@ -5644,12 +5829,15 @@
"type": "object",
"properties": {
"enable": {
"description": "Deprecated, use IAM",
"type": "boolean"
},
"password": {
"description": "Deprecated, use IAM",
"type": "string"
},
"username": {
"description": "Deprecated, use IAM",
"type": "string"
}
}

View File

@ -403,6 +403,7 @@ definitions:
enable_tls:
type: boolean
token:
description: Deprecated, use IAM
type: string
type: object
service:
@ -455,6 +456,7 @@ definitions:
passphrase:
type: string
token:
description: Deprecated, use IAM
type: string
type: object
storage:
@ -502,12 +504,16 @@ definitions:
memory:
properties:
auth:
description: Deprecated, use IAM
properties:
enable:
description: Deprecated, use IAM
type: boolean
password:
description: Deprecated, use IAM
type: string
username:
description: Deprecated, use IAM
type: string
type: object
max_size_mbytes:
@ -640,28 +646,21 @@ definitions:
auth0:
$ref: '#/definitions/api.IAMUserAuthAPIAuth0'
userpass:
$ref: '#/definitions/api.IAMUserAuthPassword'
type: string
type: object
api.IAMUserAuthAPIAuth0:
properties:
enable:
type: boolean
tenant:
$ref: '#/definitions/api.IAMAuth0Tenant'
user:
type: string
type: object
api.IAMUserAuthPassword:
properties:
enable:
type: boolean
password:
type: string
type: object
api.IAMUserAuthServices:
properties:
basic:
$ref: '#/definitions/api.IAMUserAuthPassword'
items:
type: string
type: array
token:
items:
type: string
@ -885,7 +884,7 @@ definitions:
properties:
autostart:
type: boolean
group:
domain:
type: string
id:
type: string
@ -903,6 +902,8 @@ definitions:
items:
$ref: '#/definitions/api.ProcessConfigIO'
type: array
owner:
type: string
reconnect:
type: boolean
reconnect_delay_seconds:
@ -1759,6 +1760,7 @@ definitions:
enable_tls:
type: boolean
token:
description: Deprecated, use IAM
type: string
type: object
service:
@ -1811,6 +1813,7 @@ definitions:
passphrase:
type: string
token:
description: Deprecated, use IAM
type: string
type: object
storage:
@ -1858,12 +1861,16 @@ definitions:
memory:
properties:
auth:
description: Deprecated, use IAM
properties:
enable:
description: Deprecated, use IAM
type: boolean
password:
description: Deprecated, use IAM
type: string
username:
description: Deprecated, use IAM
type: string
type: object
max_size_mbytes:
@ -2105,7 +2112,9 @@ definitions:
access_key_id:
type: string
auth:
$ref: '#/definitions/value.S3StorageAuth'
allOf:
- $ref: '#/definitions/value.S3StorageAuth'
description: Deprecated, use IAM
bucket:
type: string
endpoint:
@ -2124,10 +2133,13 @@ definitions:
value.S3StorageAuth:
properties:
enable:
description: Deprecated, use IAM
type: boolean
password:
description: Deprecated, use IAM
type: string
username:
description: Deprecated, use IAM
type: string
type: object
info:
@ -2228,6 +2240,111 @@ paths:
summary: List of nodes in the cluster
tags:
- v16.?.?
/api/v3/cluster/iam/user:
post:
consumes:
- application/json
description: Add a new identity
operationId: cluster-3-add-identity
parameters:
- description: Identity
in: body
name: config
required: true
schema:
$ref: '#/definitions/api.IAMUser'
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/api.IAMUser'
"400":
description: Bad Request
schema:
$ref: '#/definitions/api.Error'
security:
- ApiKeyAuth: []
      summary: Add a new identity
tags:
- v16.?.?
/api/v3/cluster/iam/user/{name}:
delete:
description: Delete an identity by its name
operationId: cluster-3-delete-identity
parameters:
- description: Identity name
in: path
name: name
required: true
type: string
produces:
- application/json
responses:
"200":
description: OK
schema:
type: string
"404":
description: Not Found
schema:
$ref: '#/definitions/api.Error'
security:
- ApiKeyAuth: []
summary: Delete an identity by its name
tags:
- v16.?.?
/api/v3/cluster/iam/user/{name}/policy:
put:
consumes:
- application/json
description: Replace policies of an user
operationId: cluster-3-update-user-policies
parameters:
- description: Username
in: path
name: name
required: true
type: string
- description: Domain of the acting user
in: query
name: domain
type: string
- description: Policy definitions
in: body
name: user
required: true
schema:
items:
$ref: '#/definitions/api.IAMPolicy'
type: array
produces:
- application/json
responses:
"200":
description: OK
schema:
items:
$ref: '#/definitions/api.IAMPolicy'
type: array
"400":
description: Bad Request
schema:
$ref: '#/definitions/api.Error'
"404":
description: Not Found
schema:
$ref: '#/definitions/api.Error'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/api.Error'
security:
- ApiKeyAuth: []
summary: Replace policies of an user
tags:
- v16.?.?
/api/v3/cluster/node:
get:
description: List of proxy nodes in the cluster
@ -2811,7 +2928,7 @@ paths:
consumes:
- application/json
description: Replace policies of an user
operationId: iam-3-update-user
operationId: iam-3-update-user-policies
parameters:
- description: Username
in: path
@ -3009,15 +3126,25 @@ paths:
name: id
type: string
- description: Glob pattern for process IDs. If empty all IDs will be returned.
Intersected with results from refpattern.
Intersected with results from other pattern matches.
in: query
name: idpattern
type: string
- description: Glob pattern for process references. If empty all IDs will be
returned. Intersected with results from idpattern.
returned. Intersected with results from other pattern matches.
in: query
name: refpattern
type: string
- description: Glob pattern for process owners. If empty all IDs will be returned.
Intersected with results from other pattern matches.
in: query
name: ownerpattern
type: string
- description: Glob pattern for process domain. If empty all IDs will be returned.
Intersected with results from other pattern matches.
in: query
name: domainpattern
type: string
produces:
- application/json
responses:

95
go.mod
View File

@ -3,38 +3,38 @@ module github.com/datarhei/core/v16
go 1.18
require (
github.com/99designs/gqlgen v0.17.20
github.com/99designs/gqlgen v0.17.31
github.com/Masterminds/semver/v3 v3.2.1
github.com/atrox/haikunatorgo/v2 v2.0.1
github.com/caddyserver/certmagic v0.17.2
github.com/casbin/casbin/v2 v2.60.0
github.com/casbin/casbin/v2 v2.69.1
github.com/datarhei/core-client-go/v16 v16.11.1-0.20230512155342-18a7ac72df3a
github.com/datarhei/gosrt v0.3.1
github.com/datarhei/gosrt v0.4.1
github.com/datarhei/joy4 v0.0.0-20230505074825-fde05957445a
github.com/go-playground/validator/v10 v10.11.1
github.com/go-playground/validator/v10 v10.14.0
github.com/gobwas/glob v0.2.3
github.com/golang-jwt/jwt/v4 v4.4.3
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/uuid v1.3.0
github.com/hashicorp/go-hclog v1.5.0
github.com/hashicorp/raft v1.4.0
github.com/hashicorp/raft v1.5.0
github.com/hashicorp/raft-boltdb/v2 v2.2.2
github.com/invopop/jsonschema v0.4.0
github.com/joho/godotenv v1.4.0
github.com/labstack/echo/v4 v4.9.1
github.com/joho/godotenv v1.5.1
github.com/labstack/echo/v4 v4.10.2
github.com/lithammer/shortuuid/v4 v4.0.0
github.com/mattn/go-isatty v0.0.18
github.com/minio/minio-go/v7 v7.0.47
github.com/mattn/go-isatty v0.0.19
github.com/minio/minio-go/v7 v7.0.55
github.com/prep/average v0.0.0-20200506183628-d26c465f48c3
github.com/prometheus/client_golang v1.14.0
github.com/shirou/gopsutil/v3 v3.23.3
github.com/prometheus/client_golang v1.15.1
github.com/shirou/gopsutil/v3 v3.23.4
github.com/stretchr/testify v1.8.2
github.com/swaggo/echo-swagger v1.3.5
github.com/swaggo/swag v1.8.7
github.com/swaggo/echo-swagger v1.4.0
github.com/swaggo/swag v1.16.1
github.com/vektah/gqlparser/v2 v2.5.1
github.com/xeipuuv/gojsonschema v1.2.0
go.etcd.io/bbolt v1.3.7
go.uber.org/zap v1.24.0
golang.org/x/mod v0.7.0
golang.org/x/mod v0.10.0
)
//replace github.com/datarhei/core-client-go/v16 => ../core-client-go
@ -48,73 +48,74 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/boltdb/bolt v1.3.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fatih/color v1.15.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/jsonreference v0.20.0 // indirect
github.com/go-openapi/spec v0.20.8 // indirect
github.com/go-openapi/jsonpointer v0.19.6 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect
github.com/go-openapi/spec v0.20.9 // indirect
github.com/go-openapi/swag v0.22.3 // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
github.com/hashicorp/go-msgpack v0.5.5 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.2 // indirect
github.com/iancoleman/orderedmap v0.2.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.15.15 // indirect
github.com/klauspost/cpuid/v2 v2.2.3 // indirect
github.com/klauspost/compress v1.16.5 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/labstack/gommon v0.4.0 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/libdns/libdns v0.2.1 // indirect
github.com/lufia/plan9stats v0.0.0-20220913051719-115f729f3c8c // indirect
github.com/lufia/plan9stats v0.0.0-20230326075908-cb1d2100619a // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mholt/acmez v1.0.4 // indirect
github.com/miekg/dns v1.1.50 // indirect
github.com/mholt/acmez v1.1.1 // indirect
github.com/miekg/dns v1.1.54 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/sha256-simd v1.0.0 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b // indirect
github.com/prometheus/client_model v0.3.0 // indirect
github.com/prometheus/common v0.39.0 // indirect
github.com/prometheus/procfs v0.9.0 // indirect
github.com/rogpeppe/go-internal v1.8.1 // indirect
github.com/rs/xid v1.4.0 // indirect
github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.10.1 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/shoenig/go-m1cpu v0.1.4 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect
github.com/swaggo/files v0.0.0-20220728132757-551d4a08d97a // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/sirupsen/logrus v1.9.2 // indirect
github.com/swaggo/files/v2 v2.0.0 // indirect
github.com/tklauser/go-sysconf v0.3.11 // indirect
github.com/tklauser/numcpus v0.6.0 // indirect
github.com/urfave/cli/v2 v2.8.1 // indirect
github.com/urfave/cli/v2 v2.24.4 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
github.com/yusufpapurcu/wmi v1.2.2 // indirect
go.uber.org/atomic v1.10.0 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/goleak v1.1.12 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/crypto v0.5.0 // indirect
golang.org/x/net v0.5.0 // indirect
golang.org/x/sys v0.7.0 // indirect
golang.org/x/text v0.7.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.4.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect
golang.org/x/tools v0.9.1 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

283
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
package api
import (
"fmt"
"net/http"
"sort"
"strings"
@ -11,6 +12,9 @@ import (
"github.com/datarhei/core/v16/encoding/json"
"github.com/datarhei/core/v16/http/api"
"github.com/datarhei/core/v16/http/handler/util"
"github.com/datarhei/core/v16/iam"
"github.com/datarhei/core/v16/iam/access"
"github.com/datarhei/core/v16/iam/identity"
"github.com/datarhei/core/v16/restream"
"github.com/labstack/echo/v4"
@ -21,14 +25,26 @@ import (
type ClusterHandler struct {
cluster cluster.Cluster
proxy proxy.ProxyReader
iam iam.IAM
}
// NewCluster return a new ClusterHandler type. You have to provide a cluster.
func NewCluster(cluster cluster.Cluster) *ClusterHandler {
return &ClusterHandler{
func NewCluster(cluster cluster.Cluster, iam iam.IAM) (*ClusterHandler, error) {
h := &ClusterHandler{
cluster: cluster,
proxy: cluster.ProxyReader(),
iam: iam,
}
if h.cluster == nil {
return nil, fmt.Errorf("no cluster provided")
}
if h.iam == nil {
return nil, fmt.Errorf("no IAM provided")
}
return h, nil
}
// GetNodes returns the list of proxy nodes in the cluster
@ -397,15 +413,105 @@ func (h *ClusterHandler) AddIdentity(c echo.Context) error {
return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err)
}
identity, _ := user.Unmarshal()
identity, policies := user.Unmarshal()
if err := h.cluster.AddIdentity("", identity); err != nil {
return api.Err(http.StatusBadRequest, "Invalid identity", "%s", err.Error())
}
if err := h.cluster.SetPolicies("", identity.Name, policies); err != nil {
return api.Err(http.StatusBadRequest, "Invalid policies", "%s", err.Error())
}
return c.JSON(http.StatusOK, user)
}
// UpdateIdentityPolicies replaces existing user policies
// @Summary Replace policies of an user
// @Description Replace policies of an user
// @Tags v16.?.?
// @ID cluster-3-update-user-policies
// @Accept json
// @Produce json
// @Param name path string true "Username"
// @Param domain query string false "Domain of the acting user"
// @Param user body []api.IAMPolicy true "Policy definitions"
// @Success 200 {array} api.IAMPolicy
// @Failure 400 {object} api.Error
// @Failure 404 {object} api.Error
// @Failure 500 {object} api.Error
// @Security ApiKeyAuth
// @Router /api/v3/cluster/iam/user/{name}/policy [put]
func (h *ClusterHandler) UpdateIdentityPolicies(c echo.Context) error {
	ctxuser := util.DefaultContext(c, "user", "")
	superuser := util.DefaultContext(c, "superuser", false)
	domain := util.DefaultQuery(c, "domain", "$none")
	name := util.PathParam(c, "name")

	// The acting user needs write access to this identity in the domain
	// it is acting from.
	if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "write") {
		return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to modify this user")
	}

	var iamuser identity.User
	var err error

	if name != "$anon" {
		iamuser, err = h.iam.GetIdentity(name)
		if err != nil {
			return api.Err(http.StatusNotFound, "Not found", "%s", err)
		}
	} else {
		// The anonymous identity has no stored user record; use a stand-in
		// so the enforcement and superuser checks below have a name to work with.
		iamuser = identity.User{
			Name: "$anon",
		}
	}

	policies := []api.IAMPolicy{}

	if err := util.ShouldBindJSONValidation(c, &policies, false); err != nil {
		return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err)
	}

	// Validate every submitted policy before enforcing or applying any of them.
	for _, p := range policies {
		if err := c.Validate(p); err != nil {
			return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err)
		}
	}

	accessPolicies := make([]access.Policy, 0, len(policies))

	for _, p := range policies {
		// The acting user needs write access to the identity in every
		// domain a submitted policy refers to.
		if !h.iam.Enforce(ctxuser, p.Domain, "iam:"+iamuser.Name, "write") {
			return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to write policy: %v", p)
		}

		accessPolicies = append(accessPolicies, access.Policy{
			Name:     name,
			Domain:   p.Domain,
			Resource: p.Resource,
			Actions:  p.Actions,
		})
	}

	// Only a superuser may replace the policies of another superuser.
	if !superuser && iamuser.Superuser {
		return api.Err(http.StatusForbidden, "Forbidden", "Only superusers can modify superusers")
	}

	err = h.cluster.SetPolicies("", name, accessPolicies)
	if err != nil {
		// Bug fix: the %w verb is only valid in fmt.Errorf; in a
		// Sprintf-style formatter it renders as "%!w(...)". Use %s,
		// consistent with the other api.Err calls in this handler.
		return api.Err(http.StatusInternalServerError, "", "set policies: %s", err)
	}

	// NOTE(review): the @Router annotation declares ".../policy" while the
	// route registration uses ".../policies" — confirm which is intended.
	return c.JSON(http.StatusOK, policies)
}
// ListIdentities responds with the list of identities known to the cluster.
func (h *ClusterHandler) ListIdentities(c echo.Context) error {
	return c.JSON(http.StatusOK, h.cluster.ListIdentities())
}
// Delete deletes the identity with the given name from the cluster
// @Summary Delete an identity by its name
// @Description Delete an identity by its name

View File

@ -203,7 +203,7 @@ func (h *IAMHandler) UpdateUser(c echo.Context) error {
// @Summary Replace policies of an user
// @Description Replace policies of an user
// @Tags v16.?.?
// @ID iam-3-update-user
// @ID iam-3-update-user-policies
// @Accept json
// @Produce json
// @Param name path string true "Username"
@ -241,10 +241,17 @@ func (h *IAMHandler) UpdateUserPolicies(c echo.Context) error {
policies := []api.IAMPolicy{}
if err := util.ShouldBindJSON(c, &policies); err != nil {
if err := util.ShouldBindJSONValidation(c, &policies, false); err != nil {
return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err)
}
for _, p := range policies {
err := c.Validate(p)
if err != nil {
return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err)
}
}
for _, p := range policies {
if !h.iam.Enforce(ctxuser, p.Domain, "iam:"+iamuser.Name, "write") {
return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to write policy: %v", p)

View File

@ -45,7 +45,7 @@ func getDummyRestreamHandler() (*RestreamHandler, error) {
return nil, err
}
iam, err := iam.NewIAM(iam.Config{
iam, err := iam.New(iam.Config{
PolicyAdapter: policyAdapter,
IdentityAdapter: identityAdapter,
Superuser: identity.User{

View File

@ -125,6 +125,10 @@ func NewWithConfig(config Config) echo.MiddlewareFunc {
resource := c.Request().URL.Path
var domain string
if resource == "/ping" {
return next(c)
}
if resource == "/api" || strings.HasPrefix(resource, "/api/") {
if resource == "/api/login" {
identity, err = mw.findIdentityFromUserpass(c)
@ -421,7 +425,7 @@ func (m *iammiddleware) findIdentityFromAuth0(c echo.Context) (iamidentity.Verif
}
}
identity, err := m.iam.GetVerfierFromAuth0(subject)
identity, err := m.iam.GetVerifierFromAuth0(subject)
if err != nil {
m.logger.Debug().WithFields(log.Fields{
"path": c.Request().URL.Path,

View File

@ -40,7 +40,7 @@ func getIAM() (iam.IAM, error) {
return nil, err
}
i, err := iam.NewIAM(iam.Config{
i, err := iam.New(iam.Config{
PolicyAdapter: policyAdapter,
IdentityAdapter: identityAdapter,
Superuser: iamidentity.User{

View File

@ -65,7 +65,7 @@ func DummyRestreamer(pathPrefix string) (restream.Restreamer, error) {
return nil, err
}
iam, err := iam.NewIAM(iam.Config{
iam, err := iam.New(iam.Config{
PolicyAdapter: policyAdapter,
IdentityAdapter: identityAdapter,
Superuser: iamidentity.User{

View File

@ -321,7 +321,11 @@ func NewServer(config Config) (Server, error) {
})
if config.Cluster != nil {
s.v3handler.cluster = api.NewCluster(config.Cluster)
handler, err := api.NewCluster(config.Cluster, config.IAM)
if err != nil {
return nil, fmt.Errorf("cluster handler: %w", err)
}
s.v3handler.cluster = handler
}
if middleware, err := mwcors.NewWithConfig(mwcors.Config{
@ -653,6 +657,7 @@ func (s *server) setRoutesV3(v3 *echo.Group) {
if s.v3handler.cluster != nil {
v3.GET("/cluster", s.v3handler.cluster.About)
v3.GET("/cluster/process", s.v3handler.cluster.ListProcesses)
v3.GET("/cluster/iam/user", s.v3handler.cluster.ListIdentities)
v3.GET("/cluster/node", s.v3handler.cluster.GetNodes)
v3.GET("/cluster/node/process", s.v3handler.cluster.ListNodeProcesses)
@ -666,6 +671,7 @@ func (s *server) setRoutesV3(v3 *echo.Group) {
v3.DELETE("/cluster/process/:id", s.v3handler.cluster.DeleteProcess)
v3.POST("/cluster/iam/user", s.v3handler.cluster.AddIdentity)
v3.PUT("/cluster/iam/user/:name/policies", s.v3handler.cluster.UpdateIdentityPolicies)
v3.DELETE("/cluster/iam/user/:name", s.v3handler.cluster.RemoveIdentity)
}
}

View File

@ -32,7 +32,7 @@ type IAM interface {
ReloadIndentities() error
GetVerifier(name string) (identity.Verifier, error)
GetVerfierFromAuth0(name string) (identity.Verifier, error)
GetVerifierFromAuth0(name string) (identity.Verifier, error)
GetDefaultVerifier() identity.Verifier
CreateJWT(name string) (string, string, error)
@ -56,7 +56,7 @@ type Config struct {
Logger log.Logger
}
func NewIAM(config Config) (IAM, error) {
func New(config Config) (IAM, error) {
im, err := identity.New(identity.Config{
Adapter: config.IdentityAdapter,
Superuser: config.Superuser,
@ -168,7 +168,7 @@ func (i *iam) GetVerifier(name string) (identity.Verifier, error) {
return i.im.GetVerifier(name)
}
func (i *iam) GetVerfierFromAuth0(name string) (identity.Verifier, error) {
func (i *iam) GetVerifierFromAuth0(name string) (identity.Verifier, error) {
return i.im.GetVerifierFromAuth0(name)
}

View File

@ -46,7 +46,7 @@ type UserAuthServices struct {
func (u *User) validate() error {
if len(u.Name) == 0 {
return fmt.Errorf("the name is required")
return fmt.Errorf("a name is required")
}
chars := `A-Za-z0-9:_-`
@ -508,6 +508,10 @@ func (im *identityManager) Reload() error {
continue
}
if err := u.validate(); err != nil {
continue
}
identity, err := im.create(u)
if err != nil {
continue

View File

@ -51,7 +51,7 @@ func getDummyRestreamer(portrange net.Portranger, validatorIn, validatorOut ffmp
return nil, err
}
iam, err := iam.NewIAM(iam.Config{
iam, err := iam.New(iam.Config{
PolicyAdapter: policyAdapter,
IdentityAdapter: identityAdapter,
Superuser: iamidentity.User{

File diff suppressed because it is too large Load Diff

View File

@ -22,7 +22,8 @@ Still not convinced enough to use **gqlgen**? Compare **gqlgen** with other Go g
2. Add `github.com/99designs/gqlgen` to your [project's tools.go](https://github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module)
printf '// +build tools\npackage tools\nimport _ "github.com/99designs/gqlgen"' | gofmt > tools.go
printf '// +build tools\npackage tools\nimport (_ "github.com/99designs/gqlgen"\n _ "github.com/99designs/gqlgen/graphql/introspection")' | gofmt > tools.go
go mod tidy
3. Initialise gqlgen config and generate models
@ -113,7 +114,7 @@ directive @goModel(model: String, models: [String!]) on OBJECT
| INTERFACE
| UNION
directive @goField(forceResolver: Boolean, name: String) on INPUT_FIELD_DEFINITION
directive @goField(forceResolver: Boolean, name: String, omittable: Boolean) on INPUT_FIELD_DEFINITION
| FIELD_DEFINITION
type User @goModel(model: "github.com/you/pkg/model.User") {

View File

@ -6,9 +6,10 @@ Assuming the next version is $NEW_VERSION=v0.16.0 or something like that.
./bin/release $NEW_VERSION
```
2. git-chglog -o CHANGELOG.md
3. git commit and push the CHANGELOG.md
4. Go to https://github.com/99designs/gqlgen/releases and draft new release, autogenerate the release notes, and Create a discussion for this release
5. Comment on the release discussion with any really important notes (breaking changes)
3. go generate ./...; cd _examples; go generate ./...; cd ..
4. git commit and push the CHANGELOG.md
5. Go to https://github.com/99designs/gqlgen/releases and draft new release, autogenerate the release notes, and Create a discussion for this release
6. Comment on the release discussion with any really important notes (breaking changes)
I used https://github.com/git-chglog/git-chglog to automate the changelog maintenance process for now. We could just as easily use go releaser to make the whole thing automated.

View File

@ -36,5 +36,4 @@ npm install
will write the schema to `integration/schema-fetched.graphql`, compare that with `schema-expected.graphql`
CI will run this and fail the build if the two files dont match.
CI will run this and fail the build if the two files don't match.

View File

@ -83,7 +83,11 @@ func Generate(cfg *config.Config, option ...Option) error {
}
}
// Merge again now that the generated models have been injected into the typemap
data, err := codegen.BuildData(cfg)
data_plugins := make([]interface{}, len(plugins))
for index := range plugins {
data_plugins[index] = plugins[index]
}
data, err := codegen.BuildData(cfg, data_plugins...)
if err != nil {
return fmt.Errorf("merging type systems failed: %w", err)
}

View File

@ -20,6 +20,7 @@ type Binder struct {
pkgs *code.Packages
schema *ast.Schema
cfg *Config
tctx *types.Context
References []*TypeReference
SawInvalid bool
objectCache map[string]map[string]types.Object
@ -81,6 +82,14 @@ func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
return obj.Type(), nil
}
func (b *Binder) InstantiateType(orig types.Type, targs []types.Type) (types.Type, error) {
if b.tctx == nil {
b.tctx = types.NewContext()
}
return types.Instantiate(b.tctx, orig, targs, false)
}
var (
MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
InterfaceType = types.NewInterfaceType(nil, nil)
@ -183,15 +192,17 @@ func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
type TypeReference struct {
Definition *ast.Definition
GQL *ast.Type
GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target.
Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
CastType types.Type // Before calling marshalling functions cast from/to this base type
Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function
Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler
IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety.
Definition *ast.Definition
GQL *ast.Type
GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target.
Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
CastType types.Type // Before calling marshalling functions cast from/to this base type
Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function
Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler
IsOmittable bool // Is the type wrapped with Omittable
IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety.
PointersInUmarshalInput bool // Inverse values and pointers in return.
}
func (ref *TypeReference) Elem() *TypeReference {
@ -317,7 +328,35 @@ func isIntf(t types.Type) bool {
return ok
}
func unwrapOmittable(t types.Type) (types.Type, bool) {
if t == nil {
return t, false
}
named, ok := t.(*types.Named)
if !ok {
return t, false
}
if named.Origin().String() != "github.com/99designs/gqlgen/graphql.Omittable[T any]" {
return t, false
}
return named.TypeArgs().At(0), true
}
func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
if innerType, ok := unwrapOmittable(bindTarget); ok {
if schemaType.NonNull {
return nil, fmt.Errorf("%s is wrapped with Omittable but non-null", schemaType.Name())
}
ref, err := b.TypeReference(schemaType, innerType)
if err != nil {
return nil, err
}
ref.IsOmittable = true
return ref, err
}
if !isValid(bindTarget) {
b.SawInvalid = true
return nil, fmt.Errorf("%s has an invalid type", schemaType.Name())
@ -412,6 +451,8 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret
ref.GO = bindTarget
}
ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput
return ref, nil
}

View File

@ -3,6 +3,7 @@ package config
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
@ -27,8 +28,13 @@ type Config struct {
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
OmitGetters bool `yaml:"omit_getters,omitempty"`
OmitComplexity bool `yaml:"omit_complexity,omitempty"`
OmitGQLGenFileNotice bool `yaml:"omit_gqlgen_file_notice,omitempty"`
OmitGQLGenVersionInFileNotice bool `yaml:"omit_gqlgen_version_in_file_notice,omitempty"`
StructFieldsAlwaysPointers bool `yaml:"struct_fields_always_pointers,omitempty"`
ReturnPointersInUmarshalInput bool `yaml:"return_pointers_in_unmarshalinput,omitempty"`
ResolversAlwaysReturnPointers bool `yaml:"resolvers_always_return_pointers,omitempty"`
NullableInputOmittable bool `yaml:"nullable_input_omittable,omitempty"`
SkipValidation bool `yaml:"skip_validation,omitempty"`
SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"`
Sources []*ast.Source `yaml:"-"`
@ -50,7 +56,9 @@ func DefaultConfig() *Config {
Directives: map[string]DirectiveConfig{},
Models: TypeMap{},
StructFieldsAlwaysPointers: true,
ReturnPointersInUmarshalInput: false,
ResolversAlwaysReturnPointers: true,
NullableInputOmittable: false,
}
}
@ -97,14 +105,18 @@ var path2regex = strings.NewReplacer(
// LoadConfig reads the gqlgen.yml config file
func LoadConfig(filename string) (*Config, error) {
config := DefaultConfig()
b, err := os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("unable to read config: %w", err)
}
dec := yaml.NewDecoder(bytes.NewReader(b))
return ReadConfig(bytes.NewReader(b))
}
func ReadConfig(cfgFile io.Reader) (*Config, error) {
config := DefaultConfig()
dec := yaml.NewDecoder(cfgFile)
dec.KnownFields(true)
if err := dec.Decode(config); err != nil {

View File

@ -10,12 +10,13 @@ import (
)
type ResolverConfig struct {
Filename string `yaml:"filename,omitempty"`
FilenameTemplate string `yaml:"filename_template,omitempty"`
Package string `yaml:"package,omitempty"`
Type string `yaml:"type,omitempty"`
Layout ResolverLayout `yaml:"layout,omitempty"`
DirName string `yaml:"dir"`
Filename string `yaml:"filename,omitempty"`
FilenameTemplate string `yaml:"filename_template,omitempty"`
Package string `yaml:"package,omitempty"`
Type string `yaml:"type,omitempty"`
Layout ResolverLayout `yaml:"layout,omitempty"`
DirName string `yaml:"dir"`
OmitTemplateComment bool `yaml:"omit_template_comment,omitempty"`
}
type ResolverLayout string

View File

@ -34,6 +34,7 @@ type Data struct {
MutationRoot *Object
SubscriptionRoot *Object
AugmentedSources []AugmentedSource
Plugins []interface{}
}
func (d *Data) HasEmbeddableSources() bool {
@ -76,7 +77,7 @@ func (d *Data) Directives() DirectiveList {
return res
}
func BuildData(cfg *config.Config) (*Data, error) {
func BuildData(cfg *config.Config, plugins ...interface{}) (*Data, error) {
// We reload all packages to allow packages to be compared correctly.
cfg.ReloadAllPackages()
@ -105,6 +106,7 @@ func BuildData(cfg *config.Config) (*Data, error) {
AllDirectives: dataDirectives,
Schema: b.Schema,
Interfaces: map[string]*Interface{},
Plugins: plugins,
}
for _, schemaType := range b.Schema.Types {

View File

@ -3,6 +3,7 @@ package codegen
import (
"errors"
"fmt"
goast "go/ast"
"go/types"
"log"
"reflect"
@ -502,7 +503,21 @@ func (f *Field) ResolverType() string {
return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
}
func (f *Field) IsInputObject() bool {
return f.Object.Kind == ast.InputObject
}
func (f *Field) IsRoot() bool {
return f.Object.Root
}
func (f *Field) ShortResolverDeclaration() string {
return f.ShortResolverSignature(nil)
}
// ShortResolverSignature is identical to ShortResolverDeclaration,
// but respects previous naming (return) conventions, if any.
func (f *Field) ShortResolverSignature(ft *goast.FuncType) string {
if f.Object.Kind == ast.InputObject {
return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error",
templates.CurrentImports.LookupType(f.Object.Reference()),
@ -523,11 +538,27 @@ func (f *Field) ShortResolverDeclaration() string {
if f.Object.Stream {
result = "<-chan " + result
}
res += fmt.Sprintf(") (%s, error)", result)
// Named return.
var namedV, namedE string
if ft != nil {
if ft.Results != nil && len(ft.Results.List) > 0 && len(ft.Results.List[0].Names) > 0 {
namedV = ft.Results.List[0].Names[0].Name
}
if ft.Results != nil && len(ft.Results.List) > 1 && len(ft.Results.List[1].Names) > 0 {
namedE = ft.Results.List[1].Names[0].Name
}
}
res += fmt.Sprintf(") (%s %s, %s error)", namedV, result, namedE)
return res
}
// GoResultName returns the last path segment of the field's Go type string
// (e.g. "pkg.Thing" for "example.com/pkg.Thing"), together with whether the
// full string denotes a slice (starts with "[]").
func (f *Field) GoResultName() (string, bool) {
	full := fmt.Sprintf("%v", f.TypeReference.GO)
	short := full[strings.LastIndex(full, "/")+1:]
	return short, strings.HasPrefix(full, "[]")
}
func (f *Field) ComplexitySignature() string {
res := "func(childComplexity int"
for _, arg := range f.Args {

View File

@ -51,6 +51,7 @@
}
type ComplexityRoot struct {
{{- if not .Config.OmitComplexity }}
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ ucFirst $object.Name }} struct {
@ -63,6 +64,7 @@
}
{{- end }}
{{ end }}
{{- end }}
}
{{ end }}
@ -104,6 +106,7 @@
func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
ec := executionContext{nil, e}
_ = ec
{{ if not .Config.OmitComplexity -}}
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
@ -130,6 +133,7 @@
{{ end }}
{{ end }}
}
{{- end }}
return 0, false
}

View File

@ -1,6 +1,10 @@
{{- range $input := .Inputs }}
{{- if not .HasUnmarshal }}
func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{.Type | ref}}, error) {
{{- $it := "it" }}
{{- if .PointersInUmarshalInput }}
{{- $it = "&it" }}
{{- end }}
func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{ if .PointersInUmarshalInput }}*{{ end }}{{.Type | ref}}, error) {
var it {{.Type | ref}}
asMap := map[string]interface{}{}
for k, v := range obj.(map[string]interface{}) {
@ -31,47 +35,60 @@
{{ template "implDirectives" $field }}
tmp, err := directive{{$field.ImplDirectives|len}}(ctx)
if err != nil {
return it, graphql.ErrorOnPath(ctx, err)
return {{$it}}, graphql.ErrorOnPath(ctx, err)
}
if data, ok := tmp.({{ $field.TypeReference.GO | ref }}) ; ok {
{{- if $field.IsResolver }}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return it, err
return {{$it}}, err
}
{{- else }}
it.{{$field.GoFieldName}} = data
{{- if $field.TypeReference.IsOmittable }}
it.{{$field.GoFieldName}} = graphql.OmittableOf(data)
{{- else }}
it.{{$field.GoFieldName}} = data
{{- end }}
{{- end }}
{{- if $field.TypeReference.IsNilable }}
{{- if not $field.IsResolver }}
} else if tmp == nil {
it.{{$field.GoFieldName}} = nil
{{- if $field.TypeReference.IsOmittable }}
it.{{$field.GoFieldName}} = graphql.OmittableOf[{{ $field.TypeReference.GO | ref }}](nil)
{{- else }}
it.{{$field.GoFieldName}} = nil
{{- end }}
{{- end }}
{{- end }}
} else {
err := fmt.Errorf(`unexpected type %T from directive, should be {{ $field.TypeReference.GO }}`, tmp)
return it, graphql.ErrorOnPath(ctx, err)
return {{$it}}, graphql.ErrorOnPath(ctx, err)
}
{{- else }}
{{- if $field.IsResolver }}
data, err := ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return it, err
return {{$it}}, err
}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return it, err
return {{$it}}, err
}
{{- else }}
it.{{$field.GoFieldName}}, err = ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
data, err := ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return it, err
return {{$it}}, err
}
{{- if $field.TypeReference.IsOmittable }}
it.{{$field.GoFieldName}} = graphql.OmittableOf(data)
{{- else }}
it.{{$field.GoFieldName}} = data
{{- end }}
{{- end }}
{{- end }}
{{- end }}
}
}
return it, nil
return {{$it}}, nil
}
{{- end }}
{{ end }}

View File

@ -25,14 +25,15 @@ const (
type Object struct {
*ast.Definition
Type types.Type
ResolverInterface types.Type
Root bool
Fields []*Field
Implements []*ast.Definition
DisableConcurrency bool
Stream bool
Directives []*Directive
Type types.Type
ResolverInterface types.Type
Root bool
Fields []*Field
Implements []*ast.Definition
DisableConcurrency bool
Stream bool
Directives []*Directive
PointersInUmarshalInput bool
}
func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
@ -42,11 +43,12 @@ func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
}
caser := cases.Title(language.English, cases.NoLower)
obj := &Object{
Definition: typ,
Root: b.Schema.Query == typ || b.Schema.Mutation == typ || b.Schema.Subscription == typ,
DisableConcurrency: typ == b.Schema.Mutation,
Stream: typ == b.Schema.Subscription,
Directives: dirs,
Definition: typ,
Root: b.Schema.Query == typ || b.Schema.Mutation == typ || b.Schema.Subscription == typ,
DisableConcurrency: typ == b.Schema.Mutation,
Stream: typ == b.Schema.Subscription,
Directives: dirs,
PointersInUmarshalInput: b.Config.ReturnPointersInUmarshalInput,
ResolverInterface: types.NewNamed(
types.NewTypeName(0, b.Config.Exec.Pkg(), caser.String(typ.Name)+"Resolver", nil),
nil,
@ -151,6 +153,16 @@ func (o *Object) Description() string {
return o.Definition.Description
}
// HasField reports whether the object declares a field with the given name.
// The comparison is case-sensitive.
func (o *Object) HasField(name string) bool {
	for i := range o.Fields {
		if o.Fields[i].Name == name {
			return true
		}
	}
	return false
}
func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.Definition.Name, name) {

View File

@ -49,6 +49,7 @@ type DirectiveRoot struct {
}
type ComplexityRoot struct {
{{- if not .Config.OmitComplexity }}
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ ucFirst $object.Name }} struct {
@ -61,6 +62,7 @@ type ComplexityRoot struct {
}
{{- end }}
{{ end }}
{{- end }}
}
type executableSchema struct {
@ -76,6 +78,7 @@ func (e *executableSchema) Schema() *ast.Schema {
func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
ec := executionContext{nil, e}
_ = ec
{{- if not .Config.OmitComplexity }}
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
@ -102,6 +105,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
{{ end }}
{{ end }}
}
{{- end }}
return 0, false
}

View File

@ -45,7 +45,7 @@ func (s *Imports) Reserve(path string, aliases ...string) (string, error) {
panic("empty ambient import")
}
// if we are referencing our own package we dont need an import
// if we are referencing our own package we don't need an import
if code.ImportPathForDir(s.destDir) == path {
return "", nil
}
@ -85,7 +85,7 @@ func (s *Imports) Lookup(path string) string {
path = code.NormalizeVendor(path)
// if we are referencing our own package we dont need an import
// if we are referencing our own package we don't need an import
if code.ImportPathForDir(s.destDir) == path {
return ""
}

View File

@ -75,9 +75,11 @@
return res, graphql.ErrorOnPath(ctx, err)
{{- else }}
res, err := ec.unmarshalInput{{ $type.GQL.Name }}(ctx, v)
{{- if $type.IsNilable }}
{{- if and $type.IsNilable (not $type.PointersInUmarshalInput) }}
return &res, graphql.ErrorOnPath(ctx, err)
{{- else}}
{{- else if and (not $type.IsNilable) $type.PointersInUmarshalInput }}
return *res, graphql.ErrorOnPath(ctx, err)
{{- else }}
return res, graphql.ErrorOnPath(ctx, err)
{{- end }}
{{- end }}

View File

@ -41,6 +41,7 @@ func findGoInterface(def types.Type) (*types.Interface, error) {
// equalFieldName compares a Go struct tag / field name against a GraphQL
// field name, ignoring case and underscores. The ",omitempty" JSON tag
// option is stripped from source before comparing.
func equalFieldName(source, target string) bool {
	src := strings.ReplaceAll(source, "_", "")
	src = strings.ReplaceAll(src, ",omitempty", "")
	tgt := strings.ReplaceAll(target, "_", "")
	return strings.EqualFold(src, tgt)
}

View File

@ -6,6 +6,7 @@ import (
"net/http"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// Deprecated: Please update all references to OperationContext instead
@ -106,9 +107,16 @@ func (c *OperationContext) Errorf(ctx context.Context, format string, args ...in
AddErrorf(ctx, format, args...)
}
// Error sends an error to the client, passing it through the formatter.
// Deprecated: use graphql.AddError(ctx, err) instead
// Error adds an error — or multiple errors, if the underlying type is
// gqlerror.List — onto the stack.
// The errors are then sent to the client, passing through the formatter.
func (c *OperationContext) Error(ctx context.Context, err error) {
// Flatten a gqlerror.List so each entry is reported individually.
if errList, ok := err.(gqlerror.List); ok {
for _, e := range errList {
AddError(ctx, e)
}
return
}
AddError(ctx, err)
}

View File

@ -15,25 +15,25 @@ var _ ExecutableSchema = &ExecutableSchemaMock{}
// ExecutableSchemaMock is a mock implementation of ExecutableSchema.
//
// func TestSomethingThatUsesExecutableSchema(t *testing.T) {
// func TestSomethingThatUsesExecutableSchema(t *testing.T) {
//
// // make and configure a mocked ExecutableSchema
// mockedExecutableSchema := &ExecutableSchemaMock{
// ComplexityFunc: func(typeName string, fieldName string, childComplexity int, args map[string]interface{}) (int, bool) {
// panic("mock out the Complexity method")
// },
// ExecFunc: func(ctx context.Context) ResponseHandler {
// panic("mock out the Exec method")
// },
// SchemaFunc: func() *ast.Schema {
// panic("mock out the Schema method")
// },
// }
// // make and configure a mocked ExecutableSchema
// mockedExecutableSchema := &ExecutableSchemaMock{
// ComplexityFunc: func(typeName string, fieldName string, childComplexity int, args map[string]interface{}) (int, bool) {
// panic("mock out the Complexity method")
// },
// ExecFunc: func(ctx context.Context) ResponseHandler {
// panic("mock out the Exec method")
// },
// SchemaFunc: func() *ast.Schema {
// panic("mock out the Schema method")
// },
// }
//
// // use mockedExecutableSchema in code that requires ExecutableSchema
// // and then make assertions.
// // use mockedExecutableSchema in code that requires ExecutableSchema
// // and then make assertions.
//
// }
// }
type ExecutableSchemaMock struct {
// ComplexityFunc mocks the Complexity method.
ComplexityFunc func(typeName string, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
@ -95,7 +95,8 @@ func (mock *ExecutableSchemaMock) Complexity(typeName string, fieldName string,
// ComplexityCalls gets all the calls that were made to Complexity.
// Check the length with:
// len(mockedExecutableSchema.ComplexityCalls())
//
// len(mockedExecutableSchema.ComplexityCalls())
func (mock *ExecutableSchemaMock) ComplexityCalls() []struct {
TypeName string
FieldName string
@ -132,7 +133,8 @@ func (mock *ExecutableSchemaMock) Exec(ctx context.Context) ResponseHandler {
// ExecCalls gets all the calls that were made to Exec.
// Check the length with:
// len(mockedExecutableSchema.ExecCalls())
//
// len(mockedExecutableSchema.ExecCalls())
func (mock *ExecutableSchemaMock) ExecCalls() []struct {
Ctx context.Context
} {
@ -160,7 +162,8 @@ func (mock *ExecutableSchemaMock) Schema() *ast.Schema {
// SchemaCalls gets all the calls that were made to Schema.
// Check the length with:
// len(mockedExecutableSchema.SchemaCalls())
//
// len(mockedExecutableSchema.SchemaCalls())
func (mock *ExecutableSchemaMock) SchemaCalls() []struct {
} {
var calls []struct {

View File

@ -4,17 +4,17 @@ import (
"context"
"github.com/99designs/gqlgen/graphql"
lru "github.com/hashicorp/golang-lru"
lru "github.com/hashicorp/golang-lru/v2"
)
type LRU struct {
lru *lru.Cache
lru *lru.Cache[string, any]
}
var _ graphql.Cache = &LRU{}
func New(size int) *LRU {
cache, err := lru.New(size)
cache, err := lru.New[string, any](size)
if err != nil {
// An error is only returned for non-positive cache size
// and we already checked for that.

View File

@ -0,0 +1,17 @@
package transport
import "net/http"
func writeHeaders(w http.ResponseWriter, headers map[string][]string) {
if len(headers) == 0 {
headers = map[string][]string{
"Content-Type": {"application/json"},
}
}
for key, values := range headers {
for _, value := range values {
w.Header().Add(key, value)
}
}
}

View File

@ -20,6 +20,10 @@ type MultipartForm struct {
// as multipart/form-data in memory, with the remainder stored on disk in
// temporary files.
MaxMemory int64
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = MultipartForm{}
@ -52,7 +56,7 @@ func (f MultipartForm) maxMemory() int64 {
}
func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
w.Header().Set("Content-Type", "application/json")
writeHeaders(w, f.ResponseHeaders)
start := graphql.Now()

View File

@ -0,0 +1,119 @@
package transport
import (
"io"
"mime"
"net/http"
"net/url"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
// FORM implements the application/x-www-form-urlencoded side of the default HTTP transport
type UrlEncodedForm struct {
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = UrlEncodedForm{}
// Supports reports whether this transport can handle the request: a POST
// with an application/x-www-form-urlencoded content type and no pending
// protocol upgrade (e.g. websocket).
func (h UrlEncodedForm) Supports(r *http.Request) bool {
	if r.Header.Get("Upgrade") != "" {
		return false
	}
	mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
	return err == nil && r.Method == "POST" && mediaType == "application/x-www-form-urlencoded"
}
// Do executes a GraphQL request submitted as a urlencoded form body.
// Error responses: 400 when the body cannot be read, 422 when it cannot be
// parsed, statusFor(err) when the operation context cannot be created;
// successful operations are dispatched and written as JSON.
func (h UrlEncodedForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
writeHeaders(w, h.ResponseHeaders)
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
bodyString, err := getRequestBody(r)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf("could not get form body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
// NOTE(review): parseBody returns a fresh RawParams, so the Headers and
// ReadTime assigned above are discarded here — confirm whether they should
// be copied onto the parsed params.
params, err = h.parseBody(bodyString)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
gqlErr := gqlerror.Errorf("could not cleanup body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
rc, OpErr := exec.CreateOperationContext(ctx, params)
if OpErr != nil {
w.WriteHeader(statusFor(OpErr))
resp := exec.DispatchError(graphql.WithOperationContext(ctx, rc), OpErr)
writeJson(w, resp)
return
}
var responses graphql.ResponseHandler
responses, ctx = exec.DispatchOperation(ctx, rc)
writeJson(w, responses(ctx))
}
// parseBody sniffs the body format and extracts the raw GraphQL params.
// Detection is heuristic: a `"query":` substring is treated as JSON, a
// literal `query=%7B` prefix as a urlencoded form, anything else as a plain
// `query=...` string.
// NOTE(review): a urlencoded body whose query does not start with "{"
// (escaped as %7B) falls through to the plain-text branch and is not
// unescaped — confirm this is intended.
func (h UrlEncodedForm) parseBody(bodyString string) (*graphql.RawParams, error) {
switch {
case strings.Contains(bodyString, "\"query\":"):
// body is json
return h.parseJson(bodyString)
case strings.HasPrefix(bodyString, "query=%7B"):
// body is urlencoded
return h.parseEncoded(bodyString)
default:
// body is plain text
params := &graphql.RawParams{}
params.Query = strings.TrimPrefix(bodyString, "query=")
return params, nil
}
}
// parseEncoded URL-unescapes the body and strips the leading "query="
// keyword, returning the remainder as the raw query string.
func (h UrlEncodedForm) parseEncoded(bodyString string) (*graphql.RawParams, error) {
	decoded, err := url.QueryUnescape(bodyString)
	if err != nil {
		return nil, err
	}
	return &graphql.RawParams{
		Query: strings.TrimPrefix(decoded, "query="),
	}, nil
}
// parseJson decodes a JSON body into RawParams.
// Note: jsonDecode receives a pointer to the *RawParams, so a JSON `null`
// body leaves params nil rather than zero-valued.
func (h UrlEncodedForm) parseJson(bodyString string) (*graphql.RawParams, error) {
	params := &graphql.RawParams{}
	reader := io.NopCloser(strings.NewReader(bodyString))
	if err := jsonDecode(reader, &params); err != nil {
		return nil, err
	}
	return params, nil
}

View File

@ -15,7 +15,11 @@ import (
// GET implements the GET side of the default HTTP transport
// defined in https://github.com/APIs-guru/graphql-over-http#get
type GET struct{}
type GET struct {
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = GET{}
@ -34,7 +38,7 @@ func (h GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecut
writeJsonError(w, err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
writeHeaders(w, h.ResponseHeaders)
raw := &graphql.RawParams{
Query: query.Get("query"),

View File

@ -0,0 +1,98 @@
package transport
import (
"mime"
"net/http"
"net/url"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
// GRAPHQL implements the application/graphql side of the HTTP transport
// see: https://graphql.org/learn/serving-over-http/#post-request
// If the "application/graphql" Content-Type header is present, treat
// the HTTP POST body contents as the GraphQL query string.
type GRAPHQL struct {
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = GRAPHQL{}
// Supports reports whether this transport can handle the request: a POST
// with an application/graphql content type and no pending protocol upgrade.
func (h GRAPHQL) Supports(r *http.Request) bool {
	if r.Header.Get("Upgrade") != "" {
		return false
	}
	mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
	return err == nil && r.Method == "POST" && mediaType == "application/graphql"
}
// Do executes a request whose POST body is the raw GraphQL query string
// (optionally prefixed with "query=" and/or URL-escaped; see cleanupBody).
func (h GRAPHQL) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
writeHeaders(w, h.ResponseHeaders)
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
bodyString, err := getRequestBody(r)
if err != nil {
// NOTE(review): no WriteHeader here, so this error goes out with the
// default 200 status, unlike the 422/statusFor paths below and unlike
// UrlEncodedForm.Do which sends 400 — confirm whether a status code
// should be set.
gqlErr := gqlerror.Errorf("could not get request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
params.Query, err = cleanupBody(bodyString)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
gqlErr := gqlerror.Errorf("could not cleanup body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
rc, OpErr := exec.CreateOperationContext(ctx, params)
if OpErr != nil {
w.WriteHeader(statusFor(OpErr))
resp := exec.DispatchError(graphql.WithOperationContext(ctx, rc), OpErr)
writeJson(w, resp)
return
}
var responses graphql.ResponseHandler
responses, ctx = exec.DispatchOperation(ctx, rc)
writeJson(w, responses(ctx))
}
// cleanupBody strips an optional leading "query=" keyword from the body and,
// when the remainder appears URL-escaped (starts with "%7B", an escaped "{"),
// unescapes it. Other bodies are returned unchanged.
func cleanupBody(body string) (string, error) {
	body = strings.TrimPrefix(body, "query=")
	if !strings.HasPrefix(body, "%7B") {
		return body, nil
	}
	return url.QueryUnescape(body)
}

View File

@ -1,15 +1,24 @@
package transport
import (
"fmt"
"io"
"mime"
"net/http"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
// POST implements the POST side of the default HTTP transport
// defined in https://github.com/APIs-guru/graphql-over-http#post
type POST struct{}
type POST struct {
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = POST{}
@ -26,31 +35,58 @@ func (h POST) Supports(r *http.Request) bool {
return r.Method == "POST" && mediaType == "application/json"
}
func (h POST) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
w.Header().Set("Content-Type", "application/json")
var params *graphql.RawParams
start := graphql.Now()
if err := jsonDecode(r.Body, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonErrorf(w, "json body could not be decoded: "+err.Error())
return
func getRequestBody(r *http.Request) (string, error) {
if r == nil || r.Body == nil {
return "", nil
}
body, err := io.ReadAll(r.Body)
if err != nil {
return "", fmt.Errorf("unable to get Request Body %w", err)
}
return string(body), nil
}
func (h POST) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
writeHeaders(w, h.ResponseHeaders)
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
rc, err := exec.CreateOperationContext(r.Context(), params)
bodyString, err := getRequestBody(r)
if err != nil {
w.WriteHeader(statusFor(err))
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), err)
gqlErr := gqlerror.Errorf("could not get json request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
responses, ctx := exec.DispatchOperation(r.Context(), rc)
bodyReader := io.NopCloser(strings.NewReader(bodyString))
if err = jsonDecode(bodyReader, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf(
"json request body could not be decoded: %+v body:%s",
err,
bodyString,
)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
rc, OpErr := exec.CreateOperationContext(ctx, params)
if OpErr != nil {
w.WriteHeader(statusFor(OpErr))
resp := exec.DispatchError(graphql.WithOperationContext(ctx, rc), OpErr)
writeJson(w, resp)
return
}
var responses graphql.ResponseHandler
responses, ctx = exec.DispatchOperation(ctx, rc)
writeJson(w, responses(ctx))
}

View File

@ -0,0 +1,110 @@
package transport
import (
"encoding/json"
"fmt"
"io"
"log"
"mime"
"net/http"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
type SSE struct{}
var _ graphql.Transport = SSE{}
// Supports reports whether the request asks for a server-sent-events
// subscription: the client must accept text/event-stream and POST a JSON
// body.
func (t SSE) Supports(r *http.Request) bool {
	if !strings.Contains(r.Header.Get("Accept"), "text/event-stream") {
		return false
	}
	mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
	return err == nil && r.Method == http.MethodPost && mediaType == "application/json"
}
// Do streams a GraphQL subscription over server-sent events: it validates
// the JSON request body, switches the response to text/event-stream, then
// emits one "next" event per response until the handler returns nil,
// finishing with a "complete" event.
func (t SSE) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
// Streaming requires incremental flushing; bail out if the writer can't.
flusher, ok := w.(http.Flusher)
if !ok {
SendErrorf(w, http.StatusInternalServerError, "streaming unsupported")
return
}
defer flusher.Flush()
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
// JSON for the error paths below; replaced with text/event-stream once the
// operation context is established.
w.Header().Set("Content-Type", "application/json")
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
bodyString, err := getRequestBody(r)
if err != nil {
gqlErr := gqlerror.Errorf("could not get json request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("could not get json request body: %+v", err.Error())
writeJson(w, resp)
return
}
bodyReader := io.NopCloser(strings.NewReader(bodyString))
if err = jsonDecode(bodyReader, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf(
"json request body could not be decoded: %+v body:%s",
err,
bodyString,
)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("decoding error: %+v body:%s", err.Error(), bodyString)
writeJson(w, resp)
return
}
rc, OpErr := exec.CreateOperationContext(ctx, params)
if OpErr != nil {
w.WriteHeader(statusFor(OpErr))
resp := exec.DispatchError(graphql.WithOperationContext(ctx, rc), OpErr)
writeJson(w, resp)
return
}
ctx = graphql.WithOperationContext(ctx, rc)
w.Header().Set("Content-Type", "text/event-stream")
// Initial SSE comment line opens the stream for the client.
fmt.Fprint(w, ":\n\n")
flusher.Flush()
responses, ctx := exec.DispatchOperation(ctx, rc)
// A nil response signals the end of the subscription.
for {
response := responses(ctx)
if response == nil {
break
}
writeJsonWithSSE(w, response)
flusher.Flush()
}
fmt.Fprint(w, "event: complete\n\n")
}
// writeJsonWithSSE marshals the response and writes it as a single
// server-sent "next" event. A marshal failure panics.
func writeJsonWithSSE(w io.Writer, response *graphql.Response) {
	payload, err := json.Marshal(response)
	if err != nil {
		panic(err)
	}
	fmt.Fprintf(w, "event: next\ndata: %s\n\n", payload)
}

View File

@ -24,6 +24,7 @@ type (
InitFunc WebsocketInitFunc
InitTimeout time.Duration
ErrorFunc WebsocketErrorFunc
CloseFunc WebsocketCloseFunc
KeepAlivePingInterval time.Duration
PingPongInterval time.Duration
@ -45,6 +46,9 @@ type (
WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
WebsocketErrorFunc func(ctx context.Context, err error)
// Callback called when websocket is closed.
WebsocketCloseFunc func(ctx context.Context, closeCode int)
)
var errReadTimeout = errors.New("read timeout")
@ -350,6 +354,7 @@ func (c *wsConnection) subscribe(start time.Time, msg *message) {
c.mu.Unlock()
go func() {
ctx = withSubscriptionErrorContext(ctx)
defer func() {
if r := recover(); r != nil {
err := rc.Recover(ctx, r)
@ -362,7 +367,11 @@ func (c *wsConnection) subscribe(start time.Time, msg *message) {
}
c.sendError(msg.id, gqlerr)
}
c.complete(msg.id)
if errs := getSubscriptionError(ctx); len(errs) != 0 {
c.sendError(msg.id, errs...)
} else {
c.complete(msg.id)
}
c.mu.Lock()
delete(c.active, msg.id)
c.mu.Unlock()
@ -428,4 +437,8 @@ func (c *wsConnection) close(closeCode int, message string) {
}
c.mu.Unlock()
_ = c.conn.Close()
if c.CloseFunc != nil {
c.CloseFunc(c.ctx, closeCode)
}
}

View File

@ -0,0 +1,69 @@
package transport
import (
"context"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// A private key for context that only this package can access. This is important
// to prevent collisions between different context uses
var wsSubscriptionErrorCtxKey = &wsSubscriptionErrorContextKey{"subscription-error"}

// wsSubscriptionErrorContextKey is the unexported key type behind
// wsSubscriptionErrorCtxKey; a dedicated struct type makes the key
// impossible to reproduce from outside this package.
type wsSubscriptionErrorContextKey struct {
	name string
}

// subscriptionError accumulates errors reported by a subscription resolver
// (via AddSubscriptionError) so the websocket transport can send them to
// the client when the subscription completes.
type subscriptionError struct {
	errs []*gqlerror.Error
}
// AddSubscriptionError is used to let websocket return an error message after subscription resolver returns a channel.
// for example:
//
//	func (r *subscriptionResolver) Method(ctx context.Context) (<-chan *model.Message, error) {
//		ch := make(chan *model.Message)
//		go func() {
//			defer func() {
//				close(ch)
//			}
//			// some kind of block processing (e.g.: gRPC client streaming)
//			stream, err := gRPCClientStreamRequest(ctx)
//			if err != nil {
//				transport.AddSubscriptionError(ctx, err)
//				return // must return and close channel so websocket can send error back
//			}
//			for {
//				m, err := stream.Recv()
//				if err == io.EOF {
//					return
//				}
//				if err != nil {
//					transport.AddSubscriptionError(ctx, err)
//				 	return // must return and close channel so websocket can send error back
//			 	}
//				ch <- m
//			}
//		}()
//
//		return ch, nil
//	}
//
// see https://github.com/99designs/gqlgen/pull/2506 for more details
//
// Calling it with a context that was not prepared by the websocket
// transport (see withSubscriptionErrorContext) is a no-op.
func AddSubscriptionError(ctx context.Context, err *gqlerror.Error) {
	subscriptionErrStruct := getSubscriptionErrorStruct(ctx)
	// The error store is only installed by the websocket transport. Guard
	// against a missing store so a resolver calling this under another
	// transport degrades to a no-op instead of a nil-pointer panic.
	if subscriptionErrStruct == nil {
		return
	}
	subscriptionErrStruct.errs = append(subscriptionErrStruct.errs, err)
}
// withSubscriptionErrorContext returns a copy of ctx carrying a fresh,
// empty subscriptionError store for one subscription's lifetime.
func withSubscriptionErrorContext(ctx context.Context) context.Context {
	return context.WithValue(ctx, wsSubscriptionErrorCtxKey, &subscriptionError{})
}
// getSubscriptionErrorStruct returns the subscriptionError store from ctx,
// or nil when the context was not prepared by withSubscriptionErrorContext.
func getSubscriptionErrorStruct(ctx context.Context) *subscriptionError {
	v, _ := ctx.Value(wsSubscriptionErrorCtxKey).(*subscriptionError)
	return v
}
// getSubscriptionError returns the errors recorded for this subscription so
// far. It returns nil when the context carries no error store (i.e. it was
// not prepared by withSubscriptionErrorContext) instead of dereferencing a
// nil *subscriptionError and panicking.
func getSubscriptionError(ctx context.Context) []*gqlerror.Error {
	errStore := getSubscriptionErrorStruct(ctx)
	if errStore == nil {
		return nil
	}
	return errStore.errs
}

View File

@ -0,0 +1,35 @@
package graphql
// Omittable wraps a value of type T together with a flag that records
// whether the value was explicitly provided.
type Omittable[T any] struct {
	value T
	set   bool
}

// OmittableOf wraps value in an Omittable that is marked as set.
func OmittableOf[T any](value T) Omittable[T] {
	return Omittable[T]{value: value, set: true}
}

// Value returns the wrapped value, or the zero value of T when unset.
func (o Omittable[T]) Value() T {
	v, _ := o.ValueOK()
	return v
}

// ValueOK returns the wrapped value together with a flag reporting whether
// it was set.
func (o Omittable[T]) ValueOK() (T, bool) {
	if o.set {
		return o.value, true
	}
	var zero T
	return zero, false
}

// IsSet reports whether the value was explicitly provided.
func (o Omittable[T]) IsSet() bool {
	return o.set
}

View File

@ -0,0 +1,84 @@
package playground
import (
"html/template"
"net/http"
)
// altairPage is the HTML shell for the Altair GraphQL client, served from
// the jsdelivr CDN at the version pinned in AltairHandler. Template inputs:
// .title, .version, the subresource-integrity hashes (.cssSRI, .mainSRI,
// .polyfillsSRI, .runtimeSRI), .endpoint, .subscriptionEndpoint and
// .endpointIsAbsolute.
// NOTE(review): in the relative branch the subscription URL is built from
// {{.endpoint}} rather than {{.subscriptionEndpoint}} — presumably because a
// relative ws endpoint shares the HTTP path; confirm this is intended.
var altairPage = template.Must(template.New("altair").Parse(`<!doctype html>
<html>
<head>
<meta charset="utf-8">
<title>{{.title}}</title>
<base href="https://cdn.jsdelivr.net/npm/altair-static@{{.version}}/build/dist/">
<meta name="viewport" content="width=device-width,initial-scale=1">
<link rel="icon" type="image/x-icon" href="favicon.ico">
<link href="styles.css" rel="stylesheet" crossorigin="anonymous" integrity="{{.cssSRI}}"/>
</head>
<body>
<app-root>
<style>
.loading-screen {
display: none;
}
</style>
<div class="loading-screen styled">
<div class="loading-screen-inner">
<div class="loading-screen-logo-container">
<img src="assets/img/logo_350.svg" alt="Altair">
</div>
<div class="loading-screen-loading-indicator">
<span class="loading-indicator-dot"></span>
<span class="loading-indicator-dot"></span>
<span class="loading-indicator-dot"></span>
</div>
</div>
</div>
</app-root>
<script rel="preload" as="script" type="text/javascript" crossorigin="anonymous" integrity="{{.mainSRI}}" src="main.js"></script>
<script rel="preload" as="script" type="text/javascript" crossorigin="anonymous" integrity="{{.polyfillsSRI}}" src="polyfills.js"></script>
<script rel="preload" as="script" type="text/javascript" crossorigin="anonymous" integrity="{{.runtimeSRI}}" src="runtime.js"></script>
<script>
{{- if .endpointIsAbsolute}}
const url = {{.endpoint}};
const subscriptionUrl = {{.subscriptionEndpoint}};
{{- else}}
const url = location.protocol + '//' + location.host + {{.endpoint}};
const wsProto = location.protocol == 'https:' ? 'wss:' : 'ws:';
const subscriptionUrl = wsProto + '//' + location.host + {{.endpoint}};
{{- end}}
var altairOptions = {
endpointURL: url,
subscriptionsEndpoint: subscriptionUrl,
};
window.addEventListener("load", function() {
AltairGraphQL.init(altairOptions);
});
</script>
</body>
</html>`))
// AltairHandler responsible for setting up the altair playground
func AltairHandler(title, endpoint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
err := altairPage.Execute(w, map[string]interface{}{
"title": title,
"endpoint": endpoint,
"endpointIsAbsolute": endpointHasScheme(endpoint),
"subscriptionEndpoint": getSubscriptionEndpoint(endpoint),
"version": "5.0.5",
"cssSRI": "sha256-kZ35e5mdMYN5ALEbnsrA2CLn85Oe4hBodfsih9BqNxs=",
"mainSRI": "sha256-nWdVTcGTlBDV1L04UQnqod+AJedzBCnKHv6Ct65liHE=",
"polyfillsSRI": "sha256-1aVEg2sROcCQ/RxU3AlcPaRZhZdIWA92q2M+mdd/R4c=",
"runtimeSRI": "sha256-cK2XhXqQr0WS1Z5eKNdac0rJxTD6miC3ubd+aEVMQDk=",
})
if err != nil {
panic(err)
}
}
}

View File

@ -0,0 +1,62 @@
package playground
import (
"html/template"
"net/http"
)
// apolloSandboxPage is the HTML shell embedding Apollo's hosted Sandbox
// explorer, loaded from the Apollo CDN at a pinned revision with a
// subresource-integrity hash. Template inputs: .title, .endpoint,
// .endpointIsAbsolute and .mainSRI.
var apolloSandboxPage = template.Must(template.New("ApolloSandbox").Parse(`<!doctype html>
<html>
<head>
<meta charset="utf-8">
<title>{{.title}}</title>
<meta name="viewport" content="width=device-width,initial-scale=1">
<link rel="icon" href="https://embeddable-sandbox.cdn.apollographql.com/_latest/public/assets/favicon-dark.png">
<style>
body {
margin: 0;
overflow: hidden;
}
</style>
</head>
<body>
<div style="width: 100vw; height: 100vh;" id='embedded-sandbox'></div>
<!-- NOTE: New version available at https://embeddable-sandbox.cdn.apollographql.com/ -->
<script rel="preload" as="script" crossorigin="anonymous" integrity="{{.mainSRI}}" type="text/javascript" src="https://embeddable-sandbox.cdn.apollographql.com/58165cf7452dbad480c7cb85e7acba085b3bac1d/embeddable-sandbox.umd.production.min.js"></script>
<script>
{{- if .endpointIsAbsolute}}
const url = {{.endpoint}};
{{- else}}
const url = location.protocol + '//' + location.host + {{.endpoint}};
{{- end}}
<!-- See https://www.apollographql.com/docs/graphos/explorer/sandbox/#options -->
new window.EmbeddedSandbox({
target: '#embedded-sandbox',
initialEndpoint: url,
persistExplorerState: true,
initialState: {
includeCookies: true,
pollForSchemaUpdates: false,
}
});
</script>
</body>
</html>`))
// ApolloSandboxHandler returns an http.HandlerFunc that serves the embedded
// Apollo Sandbox page pointed at the given GraphQL endpoint.
func ApolloSandboxHandler(title, endpoint string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		templateErr := apolloSandboxPage.Execute(w, map[string]interface{}{
			"title":              title,
			"endpoint":           endpoint,
			"endpointIsAbsolute": endpointHasScheme(endpoint),
			"mainSRI":            "sha256-/E4VNgAWFmbNLyXACSYoqsDAj68jC1sCMSQ0cDjf4YM=",
		})
		if templateErr != nil {
			panic(templateErr)
		}
	}
}

View File

@ -12,7 +12,7 @@ type Stats struct {
Parsing TraceTiming
Validation TraceTiming
// Stats collected by handler extensions. Dont use directly, the extension should provide a type safe way to
// Stats collected by handler extensions. Don't use directly, the extension should provide a type safe way to
// access this.
extension map[string]interface{}
}
@ -26,7 +26,7 @@ var ctxTraceStart key = "trace_start"
// StartOperationTrace captures the current time and stores it in context. This will eventually be added to request
// context but we want to grab it as soon as possible. For transports that can only handle a single graphql query
// per http requests you dont need to call this at all, the server will do it for you. For transports that handle
// per http requests you don't need to call this at all, the server will do it for you. For transports that handle
// multiple (eg batching, subscriptions) this should be called before decoding each request.
func StartOperationTrace(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxTraceStart, Now())

View File

@ -1,6 +1,7 @@
package graphql
import (
"encoding/json"
"fmt"
"io"
"strconv"
@ -55,7 +56,9 @@ func UnmarshalString(v interface{}) (string, error) {
case int64:
return strconv.FormatInt(v, 10), nil
case float64:
return fmt.Sprintf("%f", v), nil
return strconv.FormatFloat(v, 'f', -1, 64), nil
case json.Number:
return string(v), nil
case bool:
if v {
return "true", nil

View File

@ -1,3 +1,3 @@
package graphql
const Version = "v0.17.20"
const Version = "v0.17.31"

View File

@ -4,13 +4,13 @@ schema:
# Where should the generated server code go?
exec:
filename: graph/generated/generated.go
package: generated
filename: graph/generated.go
package: graph
# Uncomment to enable federation
# federation:
# filename: graph/generated/federation.go
# package: generated
# filename: graph/federation.go
# package: graph
# Where should any generated models go?
model:
@ -22,6 +22,9 @@ resolver:
layout: follow-schema
dir: graph
package: graph
filename_template: "{name}.resolvers.go"
# Optional: turn on to not generate template comments above resolvers
# omit_template_comment: false
# Optional: turn on use ` + "`" + `gqlgen:"fieldName"` + "`" + ` tags in your models
# struct_tag: json
@ -29,6 +32,15 @@ resolver:
# Optional: turn on to use []Thing instead of []*Thing
# omit_slice_element_pointers: false
# Optional: turn on to skip generation of ComplexityRoot struct content and Complexity function
# omit_complexity: false
# Optional: turn on to not generate any file notice comments in generated files
# omit_gqlgen_file_notice: false
# Optional: turn on to exclude the gqlgen version in the generated file notice. No effect if `omit_gqlgen_file_notice` is true.
# omit_gqlgen_version_in_file_notice: false
# Optional: turn off to make struct-type struct fields not use pointers
# e.g. type Thing struct { FieldA OtherThing } instead of { FieldA *OtherThing }
# struct_fields_always_pointers: true
@ -36,9 +48,18 @@ resolver:
# Optional: turn off to make resolvers return values instead of pointers for structs
# resolvers_always_return_pointers: true
# Optional: turn on to return pointers instead of values in unmarshalInput
# return_pointers_in_unmarshalinput: false
# Optional: wrap nullable input fields with Omittable
# nullable_input_omittable: true
# Optional: set to speed up generation time by not performing a final validation pass.
# skip_validation: true
# Optional: set to skip running `go mod tidy` when generating server code
# skip_mod_tidy: true
# gqlgen will search for any type names in the schema in these go packages
# if they match it will use them, otherwise it will generate them.
autobind:

View File

@ -7,10 +7,16 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime/debug"
"strings"
"sync"
"golang.org/x/tools/go/packages"
)
var once = sync.Once{}
var modInfo *debug.BuildInfo
var mode = packages.NeedName |
packages.NeedFiles |
packages.NeedImports |
@ -31,20 +37,39 @@ type Packages struct {
numNameCalls int // stupid test steam. ignore.
}
func (p *Packages) CleanupUserPackages() {
once.Do(func() {
var ok bool
modInfo, ok = debug.ReadBuildInfo()
if !ok {
modInfo = nil
}
})
// Don't cleanup github.com/99designs/gqlgen prefixed packages, they haven't changed and do not need to be reloaded
if modInfo != nil {
var toRemove []string
for k := range p.packages {
if !strings.HasPrefix(k, modInfo.Main.Path) {
toRemove = append(toRemove, k)
}
}
for _, k := range toRemove {
delete(p.packages, k)
}
} else {
p.packages = nil // Cleanup all packages if we don't know for some reason which ones to keep
}
}
// ReloadAll will call LoadAll after clearing the package cache, so we can reload
// packages in the case that the packages have changed
func (p *Packages) ReloadAll(importPaths ...string) []*packages.Package {
p.packages = nil
return p.LoadAll(importPaths...)
}
func (p *Packages) checkModuleLoaded(pkgs []*packages.Package) bool {
for i := range pkgs {
if pkgs[i] == nil || pkgs[i].Module == nil {
return false
}
if p.packages != nil {
p.CleanupUserPackages()
}
return true
return p.LoadAll(importPaths...)
}
// LoadAll will call packages.Load and return the package data for the given packages,
@ -65,13 +90,6 @@ func (p *Packages) LoadAll(importPaths ...string) []*packages.Package {
if len(missing) > 0 {
p.numLoadCalls++
pkgs, err := packages.Load(&packages.Config{Mode: mode}, missing...)
// Sometimes packages.Load not loaded the module info. Call it again to reload it.
if !p.checkModuleLoaded(pkgs) {
fmt.Println("reloading module info")
pkgs, err = packages.Load(&packages.Config{Mode: mode}, missing...)
}
if err != nil {
p.loadErrors = append(p.loadErrors, err)
}
@ -151,7 +169,7 @@ func (p *Packages) NameForPackage(importPath string) string {
pkg := p.packages[importPath]
if pkg == nil {
// otherwise do a name only lookup for it but dont put it in the package cache.
// otherwise do a name only lookup for it but don't put it in the package cache.
p.numNameCalls++
pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, importPath)
if err != nil {

View File

@ -68,7 +68,7 @@ func (r *Rewriter) getFile(filename string) string {
return r.files[filename]
}
func (r *Rewriter) GetMethodComment(structname string, methodname string) string {
func (r *Rewriter) GetPrevDecl(structname string, methodname string) *ast.FuncDecl {
for _, f := range r.pkg.Syntax {
for _, d := range f.Decls {
d, isFunc := d.(*ast.FuncDecl)
@ -89,48 +89,29 @@ func (r *Rewriter) GetMethodComment(structname string, methodname string) string
if !ok {
continue
}
if ident.Name != structname {
continue
}
return d.Doc.Text()
r.copied[d] = true
return d
}
}
return nil
}
func (r *Rewriter) GetMethodComment(structname string, methodname string) string {
d := r.GetPrevDecl(structname, methodname)
if d != nil {
return d.Doc.Text()
}
return ""
}
func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
for _, f := range r.pkg.Syntax {
for _, d := range f.Decls {
d, isFunc := d.(*ast.FuncDecl)
if !isFunc {
continue
}
if d.Name.Name != methodname {
continue
}
if d.Recv == nil || len(d.Recv.List) == 0 {
continue
}
recv := d.Recv.List[0].Type
if star, isStar := recv.(*ast.StarExpr); isStar {
recv = star.X
}
ident, ok := recv.(*ast.Ident)
if !ok {
continue
}
if ident.Name != structname {
continue
}
r.copied[d] = true
return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
}
d := r.GetPrevDecl(structname, methodname)
if d != nil {
return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
}
return ""
}

1167
vendor/github.com/99designs/gqlgen/lint.txt generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -39,6 +39,28 @@ func fileExists(filename string) bool {
return !errors.Is(err, fs.ErrNotExist)
}
// see Go source code:
// https://github.com/golang/go/blob/f57ebed35132d02e5cf016f324853217fb545e91/src/cmd/go/internal/modload/init.go#L1283
func findModuleRoot(dir string) (roots string) {
if dir == "" {
panic("dir not set")
}
dir = filepath.Clean(dir)
// Look for enclosing go.mod.
for {
if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() {
return dir
}
d := filepath.Dir(dir)
if d == dir { // the parent of the root is itself, so we can go no further
break
}
dir = d
}
return ""
}
func initFile(filename, contents string) error {
if err := os.MkdirAll(filepath.Dir(filename), 0o755); err != nil {
return fmt.Errorf("unable to create directory for file '%s': %w\n", filename, err)
@ -56,17 +78,36 @@ var initCmd = &cli.Command{
Flags: []cli.Flag{
&cli.BoolFlag{Name: "verbose, v", Usage: "show logs"},
&cli.StringFlag{Name: "config, c", Usage: "the config filename", Value: "gqlgen.yml"},
&cli.StringFlag{Name: "server", Usage: "where to write the server stub to", Value: "server.go"},
&cli.StringFlag{Name: "schema", Usage: "where to write the schema stub to", Value: "graph/schema.graphqls"},
&cli.StringFlag{
Name: "server",
Usage: "where to write the server stub to",
Value: "server.go",
},
&cli.StringFlag{
Name: "schema",
Usage: "where to write the schema stub to",
Value: "graph/schema.graphqls",
},
},
Action: func(ctx *cli.Context) error {
configFilename := ctx.String("config")
serverFilename := ctx.String("server")
schemaFilename := ctx.String("schema")
pkgName := code.ImportPathForDir(".")
cwd, err := os.Getwd()
if err != nil {
log.Println(err)
return fmt.Errorf("unable to determine current directory:%w", err)
}
pkgName := code.ImportPathForDir(cwd)
if pkgName == "" {
return fmt.Errorf("unable to determine import path for current directory, you probably need to run 'go mod init' first")
return fmt.Errorf(
"unable to determine import path for current directory, you probably need to run 'go mod init' first",
)
}
modRoot := findModuleRoot(cwd)
if modRoot == "" {
return fmt.Errorf("go.mod is missing. Please, do 'go mod init' first\n")
}
// check schema and config don't already exist
@ -75,7 +116,7 @@ var initCmd = &cli.Command{
return fmt.Errorf("%s already exists", filename)
}
}
_, err := config.LoadConfigFromDefaultLocations()
_, err = config.LoadConfigFromDefaultLocations()
if err == nil {
return fmt.Errorf("gqlgen.yml already exists in a parent directory\n")
}

View File

@ -0,0 +1,117 @@
package federation
import (
"go/types"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/99designs/gqlgen/plugin/federation/fieldset"
"github.com/vektah/gqlparser/v2/ast"
)
// Entity represents a federated type
// that was declared in the GQL schema.
type Entity struct {
	Name string // The same name as the type declaration
	Def *ast.Definition // Schema AST node backing this entity
	Resolvers []*EntityResolver // Entity resolvers derived from the @key declarations
	Requires []*Requires // Fields pulled in through @requires clauses
	Multi bool // True when the entity is resolved in bulk by a multi-entity resolver
}
// EntityResolver describes one generated entity-resolver method and, for
// multi-entity resolvers, the input type its representations are decoded into.
type EntityResolver struct {
	ResolverName string // The resolver name, such as FindUserByID
	KeyFields []*KeyField // The fields declared in @key.
	InputType types.Type // The Go generated input type for multi entity resolvers
	InputTypeName string // Schema name of that input type (e.g. FindUserByIDInput)
}

// LookupInputType renders InputType through the template import lookup so
// generated code references it with the correct package qualifier.
func (e *EntityResolver) LookupInputType() string {
	return templates.CurrentImports.LookupType(e.InputType)
}
// KeyField pairs one @key field's schema definition with its Go type mapping.
type KeyField struct {
	Definition *ast.FieldDefinition
	Field fieldset.Field // len > 1 for nested fields
	Type *config.TypeReference // The Go representation of that field type
}

// Requires represents an @requires clause
type Requires struct {
	Name string // the name of the field
	Field fieldset.Field // source Field, len > 1 for nested fields
	Type *config.TypeReference // The Go representation of that field type
}
// allFieldsAreExternal reports whether every field of the entity is external,
// either via an explicit @external directive or implicitly (key fields of a
// non-resolvable entity under federation v2).
func (e *Entity) allFieldsAreExternal(federationVersion int) bool {
	for _, field := range e.Def.Fields {
		implicit := e.isFieldImplicitlyExternal(field, federationVersion)
		explicit := field.Directives.ForName("external") != nil
		if !implicit && !explicit {
			return false
		}
	}
	return true
}
// In federation v2, key fields are implicitly external.
func (e *Entity) isFieldImplicitlyExternal(field *ast.FieldDefinition, federationVersion int) bool {
	// Key fields are only implicitly external in Federation 2
	if federationVersion != 2 {
		return false
	}
	// TODO: From the spec, it seems like if an entity is not resolvable then it should not only not have a resolver, but should not appear in the _Entity union.
	// The current implementation is a less drastic departure from the previous behavior, but should probably be reviewed.
	// See https://www.apollographql.com/docs/federation/subgraph-spec/
	// Only non-resolvable entities (@key(..., resolvable: false)) get
	// implicit externals here.
	if e.isResolvable() {
		return false
	}
	// If the field is a key field, it is implicitly external
	if e.isKeyField(field) {
		return true
	}

	return false
}
// isResolvable reports whether the entity is resolvable by this subgraph.
// Only an explicit `resolvable: false` argument on the @key directive makes
// an entity non-resolvable; a missing directive or argument defaults to
// resolvable.
func (e *Entity) isResolvable() bool {
	keyDirective := e.Def.Directives.ForName("key")
	if keyDirective == nil {
		return true
	}
	arg := keyDirective.Arguments.ForName("resolvable")
	if arg == nil {
		return true
	}
	return arg.Value.Raw != "false"
}
// isKeyField reports whether field appears in the entity's @key field set.
func (e *Entity) isKeyField(field *ast.FieldDefinition) bool {
	name := field.Name
	for _, candidate := range e.keyFields() {
		if candidate == name {
			return true
		}
	}
	return false
}
// keyFields returns the top-level field names listed in the entity's @key
// directive, or an empty slice when no usable @key is present.
func (e *Entity) keyFields() []string {
	keyDirective := e.Def.Directives.ForName("key")
	if keyDirective == nil {
		return []string{}
	}
	fieldsArg := keyDirective.Arguments.ForName("fields")
	if fieldsArg == nil {
		return []string{}
	}
	fieldSet := fieldset.New(fieldsArg.Value.Raw, nil)
	names := make([]string, len(fieldSet))
	for i, field := range fieldSet {
		// Only the first segment matters: nested selections share the same
		// top-level field name.
		names[i] = field[0]
	}
	return names
}

View File

@ -89,8 +89,6 @@ func (f *federation) InjectSourceEarly() *ast.Source {
input := `
scalar _Any
scalar _FieldSet
directive @external on FIELD_DEFINITION
directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
directive @extends on OBJECT | INTERFACE
@ -99,10 +97,12 @@ func (f *federation) InjectSourceEarly() *ast.Source {
if f.Version == 1 {
input += `
directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
directive @external on FIELD_DEFINITION
`
} else if f.Version == 2 {
input += `
directive @key(fields: _FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE
directive @external on FIELD_DEFINITION | OBJECT
directive @link(import: [String!], url: String!) repeatable on SCHEMA
directive @shareable on OBJECT | FIELD_DEFINITION
directive @tag(name: String!) repeatable on FIELD_DEFINITION | INTERFACE | OBJECT | UNION | ARGUMENT_DEFINITION | SCALAR | ENUM | ENUM_VALUE | INPUT_OBJECT | INPUT_FIELD_DEFINITION
@ -134,12 +134,12 @@ func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source {
if entityResolverInputDefinitions != "" {
entityResolverInputDefinitions += "\n\n"
}
entityResolverInputDefinitions += "input " + r.InputType + " {\n"
entityResolverInputDefinitions += "input " + r.InputTypeName + " {\n"
for _, keyField := range r.KeyFields {
entityResolverInputDefinitions += fmt.Sprintf("\t%s: %s\n", keyField.Field.ToGo(), keyField.Definition.Type.String())
}
entityResolverInputDefinitions += "}"
resolvers += fmt.Sprintf("\t%s(reps: [%s!]!): [%s]\n", r.ResolverName, r.InputType, e.Name)
resolvers += fmt.Sprintf("\t%s(reps: [%s!]!): [%s]\n", r.ResolverName, r.InputTypeName, e.Name)
} else {
resolverArgs := ""
for _, keyField := range r.KeyFields {
@ -198,44 +198,6 @@ type Entity {
}
}
// Entity represents a federated type
// that was declared in the GQL schema.
type Entity struct {
Name string // The same name as the type declaration
Def *ast.Definition
Resolvers []*EntityResolver
Requires []*Requires
Multi bool
}
type EntityResolver struct {
ResolverName string // The resolver name, such as FindUserByID
KeyFields []*KeyField // The fields declared in @key.
InputType string // The Go generated input type for multi entity resolvers
}
type KeyField struct {
Definition *ast.FieldDefinition
Field fieldset.Field // len > 1 for nested fields
Type *config.TypeReference // The Go representation of that field type
}
// Requires represents an @requires clause
type Requires struct {
Name string // the name of the field
Field fieldset.Field // source Field, len > 1 for nested fields
Type *config.TypeReference // The Go representation of that field type
}
func (e *Entity) allFieldsAreExternal() bool {
for _, field := range e.Def.Fields {
if field.Directives.ForName("external") == nil {
return false
}
}
return true
}
func (f *federation) GenerateCode(data *codegen.Data) error {
if len(f.Entities) > 0 {
if data.Objects.ByName("Entity") != nil {
@ -272,6 +234,23 @@ func (f *federation) GenerateCode(data *codegen.Data) error {
}
}
// fill in types for resolver inputs
//
for _, entity := range f.Entities {
if !entity.Multi {
continue
}
for _, resolver := range entity.Resolvers {
obj := data.Inputs.ByName(resolver.InputTypeName)
if obj == nil {
return fmt.Errorf("input object %s not found", resolver.InputTypeName)
}
resolver.InputType = obj.Type
}
}
return templates.Render(templates.Options{
PackageName: data.Config.Federation.Package,
Filename: data.Config.Federation.Filename,
@ -323,7 +302,7 @@ func (f *federation) setEntities(schema *ast.Schema) {
// extend TypeDefinedInOtherService @key(fields: "id") {
// id: ID @external
// }
if !e.allFieldsAreExternal() {
if !e.allFieldsAreExternal(f.Version) {
for _, dir := range keys {
if len(dir.Arguments) > 2 {
panic("More than two arguments provided for @key declaration.")
@ -365,9 +344,9 @@ func (f *federation) setEntities(schema *ast.Schema) {
}
e.Resolvers = append(e.Resolvers, &EntityResolver{
ResolverName: resolverName,
KeyFields: keyFields,
InputType: resolverFieldsToGo + "Input",
ResolverName: resolverName,
KeyFields: keyFields,
InputTypeName: resolverFieldsToGo + "Input",
})
}

View File

@ -133,7 +133,7 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati
{{ if and .Resolvers .Multi -}}
case "{{.Def.Name}}":
{{range $i, $_ := .Resolvers -}}
_reps := make([]*{{.InputType}}, len(reps))
_reps := make([]*{{.LookupInputType}}, len(reps))
for i, rep := range reps {
{{ range $i, $keyField := .KeyFields -}}
@ -143,7 +143,7 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati
}
{{end}}
_reps[i] = &{{.InputType}} {
_reps[i] = &{{.LookupInputType}} {
{{ range $i, $keyField := .KeyFields -}}
{{$keyField.Field.ToGo}}: id{{$i}},
{{end}}

View File

@ -17,16 +17,23 @@ import (
//go:embed models.gotpl
var modelTemplate string
type BuildMutateHook = func(b *ModelBuild) *ModelBuild
type (
BuildMutateHook = func(b *ModelBuild) *ModelBuild
FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
)
type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
// defaultFieldMutateHook is the default hook for the Plugin which applies the GoTagFieldHook.
func defaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
// DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook.
func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
	mutated, hookErr := GoFieldHook(td, fd, f)
	if hookErr != nil {
		return mutated, hookErr
	}
	return GoTagFieldHook(td, fd, mutated)
}
func defaultBuildMutateHook(b *ModelBuild) *ModelBuild {
// DefaultBuildMutateHook is the default hook for the Plugin which mutate ModelBuild.
// It is the identity hook: the build is returned unchanged.
func DefaultBuildMutateHook(b *ModelBuild) *ModelBuild {
	return b
}
@ -57,9 +64,10 @@ type Field struct {
// Name is the field's name as it appears in the schema
Name string
// GoName is the field's name as it appears in the generated Go code
GoName string
Type types.Type
Tag string
GoName string
Type types.Type
Tag string
Omittable bool
}
type Enum struct {
@ -75,8 +83,8 @@ type EnumValue struct {
func New() plugin.Plugin {
return &Plugin{
MutateHook: defaultBuildMutateHook,
FieldHook: defaultFieldMutateHook,
MutateHook: DefaultBuildMutateHook,
FieldHook: DefaultFieldMutateHook,
}
}
@ -92,7 +100,6 @@ func (m *Plugin) Name() string {
}
func (m *Plugin) MutateConfig(cfg *config.Config) error {
b := &ModelBuild{
PackageName: cfg.Model.Package,
}
@ -298,6 +305,8 @@ func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition)
binder := cfg.NewBinder()
fields := make([]*Field, 0)
var omittableType types.Type
for _, field := range schemaType.Fields {
var typ types.Type
fieldDef := cfg.Schema.Types[field.Type.Name()]
@ -365,7 +374,8 @@ func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition)
GoName: name,
Type: typ,
Description: field.Description,
Tag: `json:"` + field.Name + `"`,
Tag: getStructTagFromField(field),
Omittable: cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull,
}
if m.FieldHook != nil {
@ -376,14 +386,42 @@ func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition)
f = mf
}
if f.Omittable {
if schemaType.Kind != ast.InputObject || field.Type.NonNull {
return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name)
}
var err error
if omittableType == nil {
omittableType, err = binder.FindTypeFromName("github.com/99designs/gqlgen/graphql.Omittable")
if err != nil {
return nil, err
}
}
f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type})
if err != nil {
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
}
}
fields = append(fields, f)
}
return fields, nil
}
// GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field
// name is used when no value argument is present.
// getStructTagFromField builds the default json struct tag for field:
// nullable fields additionally get ",omitempty" so unset values are dropped
// when marshalling.
func getStructTagFromField(field *ast.FieldDefinition) string {
	if field.Type.NonNull {
		return `json:"` + field.Name + `"`
	}
	return `json:"` + field.Name + `,omitempty"`
}
// GoTagFieldHook prepends the goTag directive to the generated Field f.
// When applying the Tag to the field, the field
// name is used if no value argument is present.
func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
args := make([]string, 0)
for _, goTag := range fd.Directives.ForNames("goTag") {
@ -406,12 +444,120 @@ func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Fie
}
if len(args) > 0 {
f.Tag = f.Tag + " " + strings.Join(args, " ")
f.Tag = removeDuplicateTags(f.Tag + " " + strings.Join(args, " "))
}
return f, nil
}
// splitTagsBySpace splits tagsString on spaces, except for spaces inside
// double-quoted tag values (so `json:"a b"` stays one tag).
func splitTagsBySpace(tagsString string) []string {
	var tags []string
	var currentTag string
	inQuotes := false

	for _, c := range tagsString {
		if c == '"' {
			inQuotes = !inQuotes
		}
		if c == ' ' && !inQuotes {
			tags = append(tags, currentTag)
			currentTag = ""
		} else {
			currentTag += string(c)
		}
	}
	tags = append(tags, currentTag)

	return tags
}

// containsInvalidSpace reports whether valuesString (a tag value, quotes
// included) has leading or trailing spaces around any of its comma- or
// semicolon-separated parts, e.g. `" name"` or `"a; b"`.
func containsInvalidSpace(valuesString string) bool {
	// get rid of quotes
	valuesString = strings.ReplaceAll(valuesString, "\"", "")
	if strings.Contains(valuesString, ",") {
		// split by comma,
		values := strings.Split(valuesString, ",")
		for _, value := range values {
			if strings.TrimSpace(value) != value {
				return true
			}
		}
		return false
	}
	if strings.Contains(valuesString, ";") {
		// split by semicolon, which is common in gorm
		values := strings.Split(valuesString, ";")
		for _, value := range values {
			if strings.TrimSpace(value) != value {
				return true
			}
		}
		return false
	}
	// single value
	if strings.TrimSpace(valuesString) != valuesString {
		return true
	}
	return false
}

// removeDuplicateTags deduplicates struct-tag keys in t, keeping the LAST
// occurrence of each key (so tags appended by goTag directives override the
// generated defaults). It panics on malformed tags (no colon, or values with
// leading/trailing spaces).
func removeDuplicateTags(t string) string {
	processed := make(map[string]bool)
	tt := splitTagsBySpace(t)
	returnTags := ""

	// iterate backwards through tags so appended goTag directives are prioritized
	for i := len(tt) - 1; i >= 0; i-- {
		ti := tt[i]
		// a tag must contain ":"; correct example: json:"name"
		if !strings.Contains(ti, ":") {
			panic(fmt.Errorf("wrong format of tags: %s. goTag directive should be in format: @goTag(key: \"something\", value:\"value\"), ", t))
		}

		// Split on the FIRST colon only: tag values may themselves contain
		// colons (e.g. gorm:"type:varchar(63)"), and splitting on every
		// colon truncated such values to everything before the second colon.
		key, value, _ := strings.Cut(ti, ":")
		if processed[key] {
			continue
		}
		processed[key] = true
		if len(returnTags) > 0 {
			returnTags = " " + returnTags
		}

		if containsInvalidSpace(value) {
			panic(fmt.Errorf("tag value should not contain any leading or trailing spaces: %s", value))
		}

		returnTags = key + ":" + value + returnTags
	}

	return returnTags
}
// GoFieldHook applies the goField directive to the generated Field f.
// It honors the directive's "name" argument (overriding the generated Go
// field name) and its "omittable" argument (marking the field omittable).
// Arguments whose values fail to evaluate are silently skipped, leaving the
// field unchanged.
func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
	// (removed the dead `args := make([]string, 0); _ = args` scratch slice
	// that was never used)
	for _, goField := range fd.Directives.ForNames("goField") {
		if arg := goField.Arguments.ForName("name"); arg != nil {
			if k, err := arg.Value.Value(nil); err == nil {
				f.GoName = k.(string)
			}
		}

		if arg := goField.Arguments.ForName("omittable"); arg != nil {
			if k, err := arg.Value.Value(nil); err == nil {
				f.Omittable = k.(bool)
			}
		}
	}
	return f, nil
}
func isStruct(t types.Type) bool {
_, is := t.Underlying().(*types.Struct)
return is

View File

@ -29,3 +29,8 @@ type EarlySourceInjector interface {
// LateSourceInjector is implemented by plugins that provide an additional
// ast.Source given the already-parsed schema.
type LateSourceInjector interface {
	InjectSourceLate(schema *ast.Schema) *ast.Source
}
// ResolverImplementer is implemented by plugins that supply the body of a
// generated resolver for a given field.
type ResolverImplementer interface {
	Implement(field *codegen.Field) string
}

View File

@ -4,18 +4,21 @@ import (
_ "embed"
"errors"
"fmt"
"go/ast"
"io/fs"
"os"
"path/filepath"
"strings"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"github.com/99designs/gqlgen/codegen"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/internal/rewrite"
"github.com/99designs/gqlgen/plugin"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
//go:embed resolver.gotpl
@ -52,7 +55,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error {
file := File{}
if _, err := os.Stat(data.Config.Resolver.Filename); err == nil {
// file already exists and we dont support updating resolvers with layout = single so just return
// file already exists and we do not support updating resolvers with layout = single so just return
return nil
}
@ -65,16 +68,17 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error {
continue
}
resolver := Resolver{o, f, "// foo", `panic("not implemented")`}
resolver := Resolver{o, f, nil, "", `panic("not implemented")`}
file.Resolvers = append(file.Resolvers, &resolver)
}
}
resolverBuild := &ResolverBuild{
File: &file,
PackageName: data.Config.Resolver.Package,
ResolverType: data.Config.Resolver.Type,
HasRoot: true,
File: &file,
PackageName: data.Config.Resolver.Package,
ResolverType: data.Config.Resolver.Type,
HasRoot: true,
OmitTemplateComment: data.Config.Resolver.OmitTemplateComment,
}
return templates.Render(templates.Options{
@ -117,16 +121,30 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error {
}
structName := templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type)
implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName))
comment := strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment(structName, f.GoFieldName), `\`))
implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName))
if implementation == "" {
implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name)
}
if comment == "" {
comment = fmt.Sprintf("%v is the resolver for the %v field.", f.GoFieldName, f.Name)
// Check for Implementer Plugin
var resolver_implementer plugin.ResolverImplementer
var exists bool
for _, p := range data.Plugins {
if p_cast, ok := p.(plugin.ResolverImplementer); ok {
resolver_implementer = p_cast
exists = true
break
}
}
if exists {
implementation = resolver_implementer.Implement(f)
} else {
implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name)
}
}
resolver := Resolver{o, f, comment, implementation}
resolver := Resolver{o, f, rewriter.GetPrevDecl(structName, f.GoFieldName), comment, implementation}
fn := gqlToResolverName(data.Config.Resolver.Dir(), f.Position.Src.Name, data.Config.Resolver.FilenameTemplate)
if files[fn] == nil {
files[fn] = &File{}
@ -143,20 +161,32 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error {
for filename, file := range files {
resolverBuild := &ResolverBuild{
File: file,
PackageName: data.Config.Resolver.Package,
ResolverType: data.Config.Resolver.Type,
File: file,
PackageName: data.Config.Resolver.Package,
ResolverType: data.Config.Resolver.Type,
OmitTemplateComment: data.Config.Resolver.OmitTemplateComment,
}
var fileNotice strings.Builder
if !data.Config.OmitGQLGenFileNotice {
fileNotice.WriteString(`
// This file will be automatically regenerated based on the schema, any resolver implementations
// will be copied through when generating and any unknown code will be moved to the end.
// Code generated by github.com/99designs/gqlgen`,
)
if !data.Config.OmitGQLGenVersionInFileNotice {
fileNotice.WriteString(` version `)
fileNotice.WriteString(graphql.Version)
}
}
err := templates.Render(templates.Options{
PackageName: data.Config.Resolver.Package,
FileNotice: `
// This file will be automatically regenerated based on the schema, any resolver implementations
// will be copied through when generating and any unknown code will be moved to the end.`,
Filename: filename,
Data: resolverBuild,
Packages: data.Config.Packages,
Template: resolverTemplate,
FileNotice: fileNotice.String(),
Filename: filename,
Data: resolverBuild,
Packages: data.Config.Packages,
Template: resolverTemplate,
})
if err != nil {
return err
@ -184,9 +214,10 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error {
type ResolverBuild struct {
*File
HasRoot bool
PackageName string
ResolverType string
HasRoot bool
PackageName string
ResolverType string
OmitTemplateComment bool
}
type File struct {
@ -212,6 +243,7 @@ func (f *File) Imports() string {
type Resolver struct {
Object *codegen.Object
Field *codegen.Field
PrevDecl *ast.FuncDecl
Comment string
Implementation string
}

View File

@ -19,15 +19,21 @@
{{ end }}
{{ range $resolver := .Resolvers -}}
// {{ $resolver.Comment }}
func (r *{{lcFirst $resolver.Object.Name}}{{ucFirst $.ResolverType}}) {{$resolver.Field.GoFieldName}}{{ $resolver.Field.ShortResolverDeclaration }} {
{{ if $resolver.Comment -}}
// {{ $resolver.Comment }}
{{- else if not $.OmitTemplateComment -}}
// {{ $resolver.Field.GoFieldName }} is the resolver for the {{ $resolver.Field.Name }} field.
{{- end }}
func (r *{{lcFirst $resolver.Object.Name}}{{ucFirst $.ResolverType}}) {{$resolver.Field.GoFieldName}}{{ with $resolver.PrevDecl }}{{ $resolver.Field.ShortResolverSignature .Type }}{{ else }}{{ $resolver.Field.ShortResolverDeclaration }}{{ end }}{
{{ $resolver.Implementation }}
}
{{ end }}
{{ range $object := .Objects -}}
// {{ucFirst $object.Name}} returns {{ $object.ResolverInterface | ref }} implementation.
{{ if not $.OmitTemplateComment -}}
// {{ucFirst $object.Name}} returns {{ $object.ResolverInterface | ref }} implementation.
{{- end }}
func (r *{{$.ResolverType}}) {{ucFirst $object.Name}}() {{ $object.ResolverInterface | ref }} { return &{{lcFirst $object.Name}}{{ucFirst $.ResolverType}}{r} }
{{ end }}

View File

@ -66,13 +66,12 @@ type EnforceContext struct {
//
// File:
//
// e := casbin.NewEnforcer("path/to/basic_model.conf", "path/to/basic_policy.csv")
// e := casbin.NewEnforcer("path/to/basic_model.conf", "path/to/basic_policy.csv")
//
// MySQL DB:
//
// a := mysqladapter.NewDBAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/")
// e := casbin.NewEnforcer("path/to/basic_model.conf", a)
//
// a := mysqladapter.NewDBAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/")
// e := casbin.NewEnforcer("path/to/basic_model.conf", a)
func NewEnforcer(params ...interface{}) (*Enforcer, error) {
e := &Enforcer{logger: &log.DefaultLogger{}}
@ -276,11 +275,13 @@ func (e *Enforcer) GetNamedRoleManager(ptype string) rbac.RoleManager {
// SetRoleManager sets the current role manager.
func (e *Enforcer) SetRoleManager(rm rbac.RoleManager) {
e.invalidateMatcherMap()
e.rmMap["g"] = rm
}
// SetNamedRoleManager sets the role manager for the named policy.
func (e *Enforcer) SetNamedRoleManager(ptype string, rm rbac.RoleManager) {
e.invalidateMatcherMap()
e.rmMap[ptype] = rm
}
@ -291,6 +292,8 @@ func (e *Enforcer) SetEffector(eft effector.Effector) {
// ClearPolicy clears all policy.
func (e *Enforcer) ClearPolicy() {
e.invalidateMatcherMap()
if e.dispatcher != nil && e.autoNotifyDispatcher {
_ = e.dispatcher.ClearPolicy()
return
@ -300,6 +303,8 @@ func (e *Enforcer) ClearPolicy() {
// LoadPolicy reloads the policy from file/database.
func (e *Enforcer) LoadPolicy() error {
e.invalidateMatcherMap()
needToRebuild := false
newModel := e.model.Copy()
newModel.ClearPolicy()
@ -343,6 +348,8 @@ func (e *Enforcer) LoadPolicy() error {
}
func (e *Enforcer) loadFilteredPolicy(filter interface{}) error {
e.invalidateMatcherMap()
var filteredAdapter persist.FilteredAdapter
// Attempt to cast the Adapter as a FilteredAdapter
@ -422,6 +429,10 @@ func (e *Enforcer) initRmMap() {
_ = rm.Clear()
} else {
e.rmMap[ptype] = defaultrolemanager.NewRoleManager(10)
matchFun := "keyMatch(r_dom, p_dom)"
if strings.Contains(e.model["m"]["m"].Value, matchFun) {
e.AddNamedDomainMatchingFunc(ptype, "g", util.KeyMatch)
}
}
}
}
@ -817,17 +828,17 @@ func (p enforceParameters) Get(name string) (interface{}, error) {
func generateEvalFunction(functions map[string]govaluate.ExpressionFunction, parameters *enforceParameters) govaluate.ExpressionFunction {
return func(args ...interface{}) (interface{}, error) {
if len(args) != 1 {
return nil, fmt.Errorf("Function eval(subrule string) expected %d arguments, but got %d", 1, len(args))
return nil, fmt.Errorf("function eval(subrule string) expected %d arguments, but got %d", 1, len(args))
}
expression, ok := args[0].(string)
if !ok {
return nil, errors.New("Argument of eval(subrule string) must be a string")
return nil, errors.New("argument of eval(subrule string) must be a string")
}
expression = util.EscapeAssertion(expression)
expr, err := govaluate.NewEvaluableExpressionWithFunctions(expression, functions)
if err != nil {
return nil, fmt.Errorf("Error while parsing eval parameter: %s, %s", expression, err.Error())
return nil, fmt.Errorf("error while parsing eval parameter: %s, %s", expression, err.Error())
}
return expr.Eval(parameters)
}

View File

@ -18,6 +18,7 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/casbin/casbin/v2/persist/cache"
)
@ -25,7 +26,7 @@ import (
// CachedEnforcer wraps Enforcer and provides decision cache
type CachedEnforcer struct {
*Enforcer
expireTime uint
expireTime time.Duration
cache cache.Cache
enableCache int32
locker *sync.RWMutex
@ -45,8 +46,7 @@ func NewCachedEnforcer(params ...interface{}) (*CachedEnforcer, error) {
}
e.enableCache = 1
cache := cache.DefaultCache(make(map[string]bool))
e.cache = &cache
e.cache, _ = cache.NewDefaultCache()
e.locker = new(sync.RWMutex)
return e, nil
}
@ -132,7 +132,7 @@ func (e *CachedEnforcer) getCachedResult(key string) (res bool, err error) {
return e.cache.Get(key)
}
func (e *CachedEnforcer) SetExpireTime(expireTime uint) {
func (e *CachedEnforcer) SetExpireTime(expireTime time.Duration) {
e.expireTime = expireTime
}
@ -147,6 +147,17 @@ func (e *CachedEnforcer) setCachedResult(key string, res bool, extra ...interfac
}
func (e *CachedEnforcer) getKey(params ...interface{}) (string, bool) {
return GetCacheKey(params...)
}
// InvalidateCache deletes all the existing cached decisions.
func (e *CachedEnforcer) InvalidateCache() error {
e.locker.Lock()
defer e.locker.Unlock()
return e.cache.Clear()
}
func GetCacheKey(params ...interface{}) (string, bool) {
key := strings.Builder{}
for _, param := range params {
switch typedParam := param.(type) {
@ -160,11 +171,4 @@ func (e *CachedEnforcer) getKey(params ...interface{}) (string, bool) {
key.WriteString("$$")
}
return key.String(), true
}
// InvalidateCache deletes all the existing cached decisions.
func (e *CachedEnforcer) InvalidateCache() error {
e.locker.Lock()
defer e.locker.Unlock()
return e.cache.Clear()
}
}

View File

@ -0,0 +1,180 @@
// Copyright 2018 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package casbin
import (
"sync"
"sync/atomic"
"time"
"github.com/casbin/casbin/v2/persist/cache"
)
// SyncedCachedEnforcer wraps Enforcer and provides decision sync cache
type SyncedCachedEnforcer struct {
	*SyncedEnforcer
	expireTime  time.Duration // survival time passed to the cache on each Set; NOTE(review): Enforce reads this without taking locker — confirm no data race matters here
	cache       cache.Cache   // decision cache; must be safe for concurrent use
	enableCache int32         // accessed atomically: 1 = caching on, 0 = off
	locker      *sync.RWMutex // guards expireTime updates and cache replacement
}
// NewSyncedCachedEnforcer creates a sync cached enforcer via file or DB.
// Caching starts enabled and is backed by a SyncCache.
func NewSyncedCachedEnforcer(params ...interface{}) (*SyncedCachedEnforcer, error) {
	e := &SyncedCachedEnforcer{}
	var err error
	e.SyncedEnforcer, err = NewSyncedEnforcer(params...)
	if err != nil {
		return nil, err
	}
	e.enableCache = 1
	e.cache, _ = cache.NewSyncCache() // NewSyncCache always returns a nil error
	e.locker = new(sync.RWMutex)
	return e, nil
}
// EnableCache determines whether to enable cache on Enforce(). When enableCache
// is enabled, cached result (true | false) will be returned for previous decisions.
func (e *SyncedCachedEnforcer) EnableCache(enableCache bool) {
	if enableCache {
		atomic.StoreInt32(&e.enableCache, 1)
		return
	}
	atomic.StoreInt32(&e.enableCache, 0)
}
// Enforce decides whether a "subject" can access an "object" with the operation
// "action"; input parameters are usually: (sub, obj, act).
// If any element of rvals is not a string, the cache is ignored and the
// decision is delegated directly to the underlying SyncedEnforcer.
func (e *SyncedCachedEnforcer) Enforce(rvals ...interface{}) (bool, error) {
	// Cache disabled: fall straight through to the wrapped enforcer.
	if atomic.LoadInt32(&e.enableCache) == 0 {
		return e.SyncedEnforcer.Enforce(rvals...)
	}
	// A cache key can only be built from string parameters; otherwise bypass.
	key, ok := e.getKey(rvals...)
	if !ok {
		return e.SyncedEnforcer.Enforce(rvals...)
	}
	// Cache hit returns the stored decision; any error other than a miss
	// (cache.ErrNoSuchKey) is propagated to the caller.
	if res, err := e.getCachedResult(key); err == nil {
		return res, nil
	} else if err != cache.ErrNoSuchKey {
		return res, err
	}
	// Cache miss: evaluate, then store the fresh decision using the
	// configured expire time.
	res, err := e.SyncedEnforcer.Enforce(rvals...)
	if err != nil {
		return false, err
	}
	err = e.setCachedResult(key, res, e.expireTime)
	return res, err
}
// LoadPolicy reloads the policy from file/database. When caching is enabled,
// all previously cached decisions are cleared first so stale results cannot
// be served against the newly loaded policy.
func (e *SyncedCachedEnforcer) LoadPolicy() error {
	if atomic.LoadInt32(&e.enableCache) != 0 {
		if err := e.cache.Clear(); err != nil {
			return err
		}
	}
	return e.SyncedEnforcer.LoadPolicy()
}
// AddPolicy adds an authorization rule to the current policy, first evicting
// any cached decision for the same rule.
func (e *SyncedCachedEnforcer) AddPolicy(params ...interface{}) (bool, error) {
	if ok, err := e.checkOneAndRemoveCache(params...); !ok {
		return ok, err
	}
	return e.SyncedEnforcer.AddPolicy(params...)
}

// AddPolicies adds authorization rules to the current policy, first evicting
// any cached decisions for those rules.
func (e *SyncedCachedEnforcer) AddPolicies(rules [][]string) (bool, error) {
	if ok, err := e.checkManyAndRemoveCache(rules); !ok {
		return ok, err
	}
	return e.SyncedEnforcer.AddPolicies(rules)
}

// RemovePolicy removes an authorization rule from the current policy, first
// evicting any cached decision for the same rule.
func (e *SyncedCachedEnforcer) RemovePolicy(params ...interface{}) (bool, error) {
	if ok, err := e.checkOneAndRemoveCache(params...); !ok {
		return ok, err
	}
	return e.SyncedEnforcer.RemovePolicy(params...)
}

// RemovePolicies removes authorization rules from the current policy, first
// evicting any cached decisions for those rules.
func (e *SyncedCachedEnforcer) RemovePolicies(rules [][]string) (bool, error) {
	if ok, err := e.checkManyAndRemoveCache(rules); !ok {
		return ok, err
	}
	return e.SyncedEnforcer.RemovePolicies(rules)
}
// getCachedResult looks up a previously cached decision for key; the cache
// reports a miss via cache.ErrNoSuchKey.
func (e *SyncedCachedEnforcer) getCachedResult(key string) (res bool, err error) {
	return e.cache.Get(key)
}

// SetExpireTime sets the survival time applied to newly cached decisions.
func (e *SyncedCachedEnforcer) SetExpireTime(expireTime time.Duration) {
	e.locker.Lock()
	defer e.locker.Unlock()
	e.expireTime = expireTime
}
// SetCache replaces the decision cache. The supplied cache must be safe for
// concurrent use (e.g. a SyncCache), since Enforce may be called from
// multiple goroutines.
func (e *SyncedCachedEnforcer) SetCache(c cache.Cache) {
	e.locker.Lock()
	defer e.locker.Unlock()
	e.cache = c
}
// setCachedResult stores a decision; extra is forwarded to the cache
// (first element is the entry's time-to-live).
func (e *SyncedCachedEnforcer) setCachedResult(key string, res bool, extra ...interface{}) error {
	return e.cache.Set(key, res, extra...)
}

// getKey builds the cache key for the given request parameters; ok is false
// when a key cannot be built (non-string parameters).
func (e *SyncedCachedEnforcer) getKey(params ...interface{}) (string, bool) {
	return GetCacheKey(params...)
}

// InvalidateCache deletes all the existing cached decisions.
func (e *SyncedCachedEnforcer) InvalidateCache() error {
	return e.cache.Clear()
}
// checkOneAndRemoveCache evicts the cached decision for a single rule, when
// caching is enabled and a key can be built from params. A cache miss is not
// an error; any other Delete failure is returned with ok=false.
func (e *SyncedCachedEnforcer) checkOneAndRemoveCache(params ...interface{}) (bool, error) {
	if atomic.LoadInt32(&e.enableCache) == 0 {
		return true, nil
	}
	key, ok := e.getKey(params...)
	if !ok {
		return true, nil
	}
	if err := e.cache.Delete(key); err != nil && err != cache.ErrNoSuchKey {
		return false, err
	}
	return true, nil
}
// checkManyAndRemoveCache evicts the cached decisions for a batch of rules,
// when caching is enabled. A cache miss is not an error; any other Delete
// failure is returned with ok=false.
func (e *SyncedCachedEnforcer) checkManyAndRemoveCache(rules [][]string) (bool, error) {
	if len(rules) != 0 {
		if atomic.LoadInt32(&e.enableCache) != 0 {
			for _, rule := range rules {
				// Size the argument slice per rule: rules are not guaranteed
				// to share one arity. Sizing from rules[0] (as before) would
				// panic on a longer rule and leak stale trailing parameters
				// into the key of a shorter one.
				irule := make([]interface{}, len(rule))
				for i, param := range rule {
					irule[i] = param
				}
				key, _ := e.getKey(irule...)
				if err := e.cache.Delete(key); err != nil && err != cache.ErrNoSuchKey {
					return false, err
				}
			}
		}
	}
	return true, nil
}

View File

@ -110,6 +110,8 @@ type IEnforcer interface {
AddPolicies(rules [][]string) (bool, error)
AddNamedPolicy(ptype string, params ...interface{}) (bool, error)
AddNamedPolicies(ptype string, rules [][]string) (bool, error)
AddPoliciesEx(rules [][]string) (bool, error)
AddNamedPoliciesEx(ptype string, rules [][]string) (bool, error)
RemovePolicy(params ...interface{}) (bool, error)
RemovePolicies(rules [][]string) (bool, error)
RemoveFilteredPolicy(fieldIndex int, fieldValues ...string) (bool, error)
@ -120,8 +122,10 @@ type IEnforcer interface {
HasNamedGroupingPolicy(ptype string, params ...interface{}) bool
AddGroupingPolicy(params ...interface{}) (bool, error)
AddGroupingPolicies(rules [][]string) (bool, error)
AddGroupingPoliciesEx(rules [][]string) (bool, error)
AddNamedGroupingPolicy(ptype string, params ...interface{}) (bool, error)
AddNamedGroupingPolicies(ptype string, rules [][]string) (bool, error)
AddNamedGroupingPoliciesEx(ptype string, rules [][]string) (bool, error)
RemoveGroupingPolicy(params ...interface{}) (bool, error)
RemoveGroupingPolicies(rules [][]string) (bool, error)
RemoveFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) (bool, error)
@ -136,10 +140,13 @@ type IEnforcer interface {
UpdateGroupingPolicy(oldRule []string, newRule []string) (bool, error)
UpdateGroupingPolicies(oldRules [][]string, newRules [][]string) (bool, error)
UpdateNamedGroupingPolicy(ptype string, oldRule []string, newRule []string) (bool, error)
UpdateNamedGroupingPolicies(ptype string, oldRules [][]string, newRules [][]string) (bool, error)
/* Management API with autoNotifyWatcher disabled */
SelfAddPolicy(sec string, ptype string, rule []string) (bool, error)
SelfAddPolicies(sec string, ptype string, rules [][]string) (bool, error)
SelfAddPoliciesEx(sec string, ptype string, rules [][]string) (bool, error)
SelfRemovePolicy(sec string, ptype string, rule []string) (bool, error)
SelfRemovePolicies(sec string, ptype string, rules [][]string) (bool, error)
SelfRemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) (bool, error)

View File

@ -376,6 +376,15 @@ func (e *SyncedEnforcer) AddPolicies(rules [][]string) (bool, error) {
return e.Enforcer.AddPolicies(rules)
}
// AddPoliciesEx adds authorization rules to the current policy.
// If the rule already exists, the rule will not be added.
// But unlike AddPolicies, other non-existent rules are added instead of returning false directly.
// The enforcer's write lock is held for the duration of the call.
func (e *SyncedEnforcer) AddPoliciesEx(rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddPoliciesEx(rules)
}
// AddNamedPolicy adds an authorization rule to the current named policy.
// If the rule already exists, the function returns false and the rule will not be added.
// Otherwise the function returns true by adding the new rule.
@ -394,6 +403,15 @@ func (e *SyncedEnforcer) AddNamedPolicies(ptype string, rules [][]string) (bool,
return e.Enforcer.AddNamedPolicies(ptype, rules)
}
// AddNamedPoliciesEx adds authorization rules to the current named policy.
// If the rule already exists, the rule will not be added.
// But unlike AddNamedPolicies, other non-existent rules are added instead of returning false directly.
// The enforcer's write lock is held for the duration of the call.
func (e *SyncedEnforcer) AddNamedPoliciesEx(ptype string, rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddNamedPoliciesEx(ptype, rules)
}
// RemovePolicy removes an authorization rule from the current policy.
func (e *SyncedEnforcer) RemovePolicy(params ...interface{}) (bool, error) {
e.m.Lock()
@ -506,6 +524,15 @@ func (e *SyncedEnforcer) AddGroupingPolicies(rules [][]string) (bool, error) {
return e.Enforcer.AddGroupingPolicies(rules)
}
// AddGroupingPoliciesEx adds role inheritance rules to the current policy.
// If the rule already exists, the rule will not be added.
// But unlike AddGroupingPolicies, other non-existent rules are added instead of returning false directly.
// The enforcer's write lock is held for the duration of the call.
func (e *SyncedEnforcer) AddGroupingPoliciesEx(rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddGroupingPoliciesEx(rules)
}
// AddNamedGroupingPolicy adds a named role inheritance rule to the current policy.
// If the rule already exists, the function returns false and the rule will not be added.
// Otherwise the function returns true by adding the new rule.
@ -524,6 +551,15 @@ func (e *SyncedEnforcer) AddNamedGroupingPolicies(ptype string, rules [][]string
return e.Enforcer.AddNamedGroupingPolicies(ptype, rules)
}
// AddNamedGroupingPoliciesEx adds named role inheritance rules to the current policy.
// If the rule already exists, the rule will not be added.
// But unlike AddNamedGroupingPolicies, other non-existent rules are added instead of returning false directly.
// The enforcer's write lock is held for the duration of the call.
func (e *SyncedEnforcer) AddNamedGroupingPoliciesEx(ptype string, rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddNamedGroupingPoliciesEx(ptype, rules)
}
// RemoveGroupingPolicy removes a role inheritance rule from the current policy.
func (e *SyncedEnforcer) RemoveGroupingPolicy(params ...interface{}) (bool, error) {
e.m.Lock()
@ -596,3 +632,51 @@ func (e *SyncedEnforcer) AddFunction(name string, function govaluate.ExpressionF
defer e.m.Unlock()
e.Enforcer.AddFunction(name, function)
}
// SelfAddPolicy is the synchronized version of Enforcer.SelfAddPolicy;
// it holds the enforcer's write lock for the duration of the call.
func (e *SyncedEnforcer) SelfAddPolicy(sec string, ptype string, rule []string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SelfAddPolicy(sec, ptype, rule)
}

// SelfAddPolicies is the synchronized version of Enforcer.SelfAddPolicies;
// it holds the enforcer's write lock for the duration of the call.
func (e *SyncedEnforcer) SelfAddPolicies(sec string, ptype string, rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SelfAddPolicies(sec, ptype, rules)
}

// SelfAddPoliciesEx is the synchronized version of Enforcer.SelfAddPoliciesEx;
// it holds the enforcer's write lock for the duration of the call.
func (e *SyncedEnforcer) SelfAddPoliciesEx(sec string, ptype string, rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SelfAddPoliciesEx(sec, ptype, rules)
}

// SelfRemovePolicy is the synchronized version of Enforcer.SelfRemovePolicy;
// it holds the enforcer's write lock for the duration of the call.
func (e *SyncedEnforcer) SelfRemovePolicy(sec string, ptype string, rule []string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SelfRemovePolicy(sec, ptype, rule)
}

// SelfRemovePolicies is the synchronized version of Enforcer.SelfRemovePolicies;
// it holds the enforcer's write lock for the duration of the call.
func (e *SyncedEnforcer) SelfRemovePolicies(sec string, ptype string, rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SelfRemovePolicies(sec, ptype, rules)
}

// SelfRemoveFilteredPolicy is the synchronized version of
// Enforcer.SelfRemoveFilteredPolicy; it holds the enforcer's write lock for
// the duration of the call.
func (e *SyncedEnforcer) SelfRemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SelfRemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...)
}

// SelfUpdatePolicy is the synchronized version of Enforcer.SelfUpdatePolicy;
// it holds the enforcer's write lock for the duration of the call.
func (e *SyncedEnforcer) SelfUpdatePolicy(sec string, ptype string, oldRule, newRule []string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SelfUpdatePolicy(sec, ptype, oldRule, newRule)
}

// SelfUpdatePolicies is the synchronized version of Enforcer.SelfUpdatePolicies;
// it holds the enforcer's write lock for the duration of the call.
func (e *SyncedEnforcer) SelfUpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SelfUpdatePolicies(sec, ptype, oldRules, newRules)
}

View File

@ -18,9 +18,13 @@ import "errors"
// Global errors for rbac defined here
var (
ERR_NAME_NOT_FOUND = errors.New("error: name does not exist")
ERR_DOMAIN_PARAMETER = errors.New("error: domain should be 1 parameter")
ERR_LINK_NOT_FOUND = errors.New("error: link between name1 and name2 does not exist")
ERR_USE_DOMAIN_PARAMETER = errors.New("error: useDomain should be 1 parameter")
INVALID_FIELDVAULES_PARAMETER = errors.New("fieldValues requires at least one parameter")
ErrNameNotFound = errors.New("error: name does not exist")
ErrDomainParameter = errors.New("error: domain should be 1 parameter")
ErrLinkNotFound = errors.New("error: link between name1 and name2 does not exist")
ErrUseDomainParameter = errors.New("error: useDomain should be 1 parameter")
ErrInvalidFieldValuesParameter = errors.New("fieldValues requires at least one parameter")
// GetAllowedObjectConditions errors
ErrObjCondition = errors.New("need to meet the prefix required by the object condition")
ErrEmptyCondition = errors.New("GetAllowedObjectConditions have an empty condition")
)

View File

@ -64,13 +64,15 @@ func (e *Enforcer) addPolicyWithoutNotify(sec string, ptype string, rule []strin
return true, nil
}
// addPolicies adds rules to the current policy.
func (e *Enforcer) addPoliciesWithoutNotify(sec string, ptype string, rules [][]string) (bool, error) {
// addPoliciesWithoutNotify adds rules to the current policy without notify
// If autoRemoveRepeat == true, existing rules are automatically filtered
// Otherwise, false is returned directly
func (e *Enforcer) addPoliciesWithoutNotify(sec string, ptype string, rules [][]string, autoRemoveRepeat bool) (bool, error) {
if e.dispatcher != nil && e.autoNotifyDispatcher {
return true, e.dispatcher.AddPolicies(sec, ptype, rules)
}
if e.model.HasPolicies(sec, ptype, rules) {
if !autoRemoveRepeat && e.model.HasPolicies(sec, ptype, rules) {
return false, nil
}
@ -225,7 +227,7 @@ func (e *Enforcer) removePoliciesWithoutNotify(sec string, ptype string, rules [
// removeFilteredPolicy removes rules based on field filters from the current policy.
func (e *Enforcer) removeFilteredPolicyWithoutNotify(sec string, ptype string, fieldIndex int, fieldValues []string) (bool, error) {
if len(fieldValues) == 0 {
return false, Err.INVALID_FIELDVAULES_PARAMETER
return false, Err.ErrInvalidFieldValuesParameter
}
if e.dispatcher != nil && e.autoNotifyDispatcher {
@ -321,8 +323,10 @@ func (e *Enforcer) addPolicy(sec string, ptype string, rule []string) (bool, err
}
// addPolicies adds rules to the current policy.
func (e *Enforcer) addPolicies(sec string, ptype string, rules [][]string) (bool, error) {
ok, err := e.addPoliciesWithoutNotify(sec, ptype, rules)
// If autoRemoveRepeat == true, existing rules are automatically filtered
// Otherwise, false is returned directly
func (e *Enforcer) addPolicies(sec string, ptype string, rules [][]string, autoRemoveRepeat bool) (bool, error) {
ok, err := e.addPoliciesWithoutNotify(sec, ptype, rules, autoRemoveRepeat)
if !ok || err != nil {
return ok, err
}

View File

@ -207,6 +207,13 @@ func (e *Enforcer) AddPolicies(rules [][]string) (bool, error) {
return e.AddNamedPolicies("p", rules)
}
// AddPoliciesEx adds authorization rules to the current policy.
// If the rule already exists, the rule will not be added.
// But unlike AddPolicies, other non-existent rules are added instead of returning false directly
func (e *Enforcer) AddPoliciesEx(rules [][]string) (bool, error) {
return e.AddNamedPoliciesEx("p", rules)
}
// AddNamedPolicy adds an authorization rule to the current named policy.
// If the rule already exists, the function returns false and the rule will not be added.
// Otherwise the function returns true by adding the new rule.
@ -227,7 +234,14 @@ func (e *Enforcer) AddNamedPolicy(ptype string, params ...interface{}) (bool, er
// If the rule already exists, the function returns false for the corresponding rule and the rule will not be added.
// Otherwise the function returns true for the corresponding by adding the new rule.
func (e *Enforcer) AddNamedPolicies(ptype string, rules [][]string) (bool, error) {
return e.addPolicies("p", ptype, rules)
return e.addPolicies("p", ptype, rules, false)
}
// AddNamedPoliciesEx adds authorization rules to the current named policy.
// If the rule already exists, the rule will not be added.
// But unlike AddNamedPolicies, other non-existent rules are added instead of returning false directly
func (e *Enforcer) AddNamedPoliciesEx(ptype string, rules [][]string) (bool, error) {
return e.addPolicies("p", ptype, rules, true)
}
// RemovePolicy removes an authorization rule from the current policy.
@ -327,6 +341,13 @@ func (e *Enforcer) AddGroupingPolicies(rules [][]string) (bool, error) {
return e.AddNamedGroupingPolicies("g", rules)
}
// AddGroupingPoliciesEx adds role inheritance rules to the current policy.
// If the rule already exists, the rule will not be added.
// But unlike AddGroupingPolicies, other non-existent rules are added instead of returning false directly
func (e *Enforcer) AddGroupingPoliciesEx(rules [][]string) (bool, error) {
return e.AddNamedGroupingPoliciesEx("g", rules)
}
// AddNamedGroupingPolicy adds a named role inheritance rule to the current policy.
// If the rule already exists, the function returns false and the rule will not be added.
// Otherwise the function returns true by adding the new rule.
@ -351,7 +372,14 @@ func (e *Enforcer) AddNamedGroupingPolicy(ptype string, params ...interface{}) (
// If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
// Otherwise the function returns true for the corresponding policy rule by adding the new rule.
func (e *Enforcer) AddNamedGroupingPolicies(ptype string, rules [][]string) (bool, error) {
return e.addPolicies("g", ptype, rules)
return e.addPolicies("g", ptype, rules, false)
}
// AddNamedGroupingPoliciesEx adds named role inheritance rules to the current policy.
// If the rule already exists, the rule will not be added.
// But unlike AddNamedGroupingPolicies, other non-existent rules are added instead of returning false directly
func (e *Enforcer) AddNamedGroupingPoliciesEx(ptype string, rules [][]string) (bool, error) {
return e.addPolicies("g", ptype, rules, true)
}
// RemoveGroupingPolicy removes a role inheritance rule from the current policy.
@ -424,7 +452,11 @@ func (e *Enforcer) SelfAddPolicy(sec string, ptype string, rule []string) (bool,
}
func (e *Enforcer) SelfAddPolicies(sec string, ptype string, rules [][]string) (bool, error) {
return e.addPoliciesWithoutNotify(sec, ptype, rules)
return e.addPoliciesWithoutNotify(sec, ptype, rules, false)
}
func (e *Enforcer) SelfAddPoliciesEx(sec string, ptype string, rules [][]string) (bool, error) {
return e.addPoliciesWithoutNotify(sec, ptype, rules, true)
}
func (e *Enforcer) SelfRemovePolicy(sec string, ptype string, rule []string) (bool, error) {

View File

@ -20,7 +20,7 @@ var ErrNoSuchKey = errors.New("there's no such key existing in cache")
type Cache interface {
// Set puts key and value into cache.
// First parameter for extra should be uint denoting expected survival time.
// First parameter for extra should be time.Time object denoting expected survival time.
// If survival time equals 0 or less, the key will always be survival.
Set(key string, value bool, extra ...interface{}) error

View File

@ -0,0 +1,86 @@
// Copyright 2021 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"sync"
"time"
)
// SyncCache is a concurrency-safe decision cache: a DefaultCache guarded by
// an embedded RWMutex.
type SyncCache struct {
	cache DefaultCache
	sync.RWMutex
}
// Set stores a decision under key. The optional first element of extra is
// the entry's time-to-live and must be a time.Duration — any other type
// panics on the assertion below. A ttl <= 0 means the entry never expires
// (Get only honours positive ttl values).
func (c *SyncCache) Set(key string, value bool, extra ...interface{}) error {
	ttl := time.Duration(-1) // default: no expiry
	if len(extra) > 0 {
		ttl = extra[0].(time.Duration)
	}
	c.Lock()
	defer c.Unlock()
	c.cache[key] = cacheItem{
		value:     value,
		expiresAt: time.Now().Add(ttl),
		ttl:       ttl,
	}
	return nil
}
// Get returns the cached value for key, or ErrNoSuchKey when the key is
// absent or its TTL has elapsed (expired entries are removed lazily here).
func (c *SyncCache) Get(key string) (bool, error) {
	c.RLock()
	item, found := c.cache[key]
	c.RUnlock()

	if !found {
		return false, ErrNoSuchKey
	}

	// A positive ttl means the entry can expire; drop it once stale.
	if item.ttl > 0 && time.Now().After(item.expiresAt) {
		c.Lock()
		defer c.Unlock()
		delete(c.cache, key)
		return false, ErrNoSuchKey
	}

	return item.value, nil
}
// Delete removes key from the cache, returning ErrNoSuchKey when it is absent.
//
// The existence check and the delete are performed under a single write lock:
// the previous implementation checked under RLock and deleted under a separate
// Lock, leaving a window in which two goroutines could both pass the check.
func (c *SyncCache) Delete(key string) error {
	c.Lock()
	defer c.Unlock()

	if _, ok := c.cache[key]; !ok {
		return ErrNoSuchKey
	}

	delete(c.cache, key)
	return nil
}
// Clear discards every entry by swapping in a fresh map under the write lock.
func (c *SyncCache) Clear() error {
	c.Lock()
	defer c.Unlock()
	c.cache = make(DefaultCache)
	return nil
}
// NewSyncCache constructs an empty, goroutine-safe Cache.
func NewSyncCache() (Cache, error) {
	return &SyncCache{cache: make(DefaultCache)}, nil
}

View File

@ -14,10 +14,26 @@
package cache
type DefaultCache map[string]bool
import "time"
type cacheItem struct {
value bool
expiresAt time.Time
ttl time.Duration
}
type DefaultCache map[string]cacheItem
func (c *DefaultCache) Set(key string, value bool, extra ...interface{}) error {
(*c)[key] = value
ttl := time.Duration(-1)
if len(extra) > 0 {
ttl = extra[0].(time.Duration)
}
(*c)[key] = cacheItem{
value: value,
expiresAt: time.Now().Add(ttl),
ttl: ttl,
}
return nil
}
@ -25,7 +41,11 @@ func (c *DefaultCache) Get(key string) (bool, error) {
if res, ok := (*c)[key]; !ok {
return false, ErrNoSuchKey
} else {
return res, nil
if res.ttl > 0 && time.Now().After(res.expiresAt) {
delete(*c, key)
return false, ErrNoSuchKey
}
return res.value, nil
}
}
@ -42,3 +62,8 @@ func (c *DefaultCache) Clear() error {
*c = make(DefaultCache)
return nil
}
func NewDefaultCache() (Cache, error) {
cache := make(DefaultCache)
return &cache, nil
}

View File

@ -507,7 +507,7 @@ func (dm *DomainManager) getDomain(domains ...string) (domain string, err error)
case 1:
return domains[0], nil
default:
return "", errors.ERR_DOMAIN_PARAMETER
return "", errors.ErrDomainParameter
}
}

View File

@ -15,6 +15,8 @@
package casbin
import (
"strings"
"github.com/casbin/casbin/v2/constant"
"github.com/casbin/casbin/v2/errors"
"github.com/casbin/casbin/v2/util"
@ -84,7 +86,7 @@ func (e *Enforcer) DeleteRolesForUser(user string, domain ...string) (bool, erro
if len(domain) == 0 {
args = []string{user}
} else if len(domain) > 1 {
return false, errors.ERR_DOMAIN_PARAMETER
return false, errors.ErrDomainParameter
} else {
args = []string{user, "", domain[0]}
}
@ -304,7 +306,7 @@ func (e *Enforcer) GetNamedImplicitPermissionsForUser(ptype string, user string,
permission = append(permission, deepCopyPolicy(rule))
}
} else if len(domain) > 1 {
return nil, errors.ERR_DOMAIN_PARAMETER
return nil, errors.ErrDomainParameter
} else {
d := domain[0]
matched := rm.Match(d, rule[domainIndex])
@ -414,3 +416,149 @@ func deepCopyPolicy(src []string) []string {
copy(newRule, src)
return newRule
}
// GetAllowedObjectConditions returns a string array of object conditions that the user can access.
// For example: conditions, err := e.GetAllowedObjectConditions("alice", "read", "r.obj.")
// Note:
//
// 0. prefix: You can customize the prefix of the object conditions, and "r.obj." is commonly used as a prefix.
// After removing the prefix, the remaining part is the condition of the object.
// If there is an obj policy that does not meet the prefix requirement, errors.ErrObjCondition is returned.
//
// 1. If the resulting conditions array is empty, errors.ErrEmptyCondition is returned.
// This error is returned because some data adapters' ORM return full table data by default
// when they receive an empty condition, which tends to behave contrary to expectations. (e.g. GORM)
// If you are using an adapter that does not behave like this, you can choose to ignore this error.
func (e *Enforcer) GetAllowedObjectConditions(user string, action string, prefix string) ([]string, error) {
	permissions, err := e.GetImplicitPermissionsForUser(user)
	if err != nil {
		return nil, err
	}

	var objectConditions []string
	for _, rule := range permissions {
		// rule layout: {sub, obj, act}
		if rule[2] != action {
			continue
		}
		if !strings.HasPrefix(rule[1], prefix) {
			return nil, errors.ErrObjCondition
		}
		objectConditions = append(objectConditions, strings.TrimPrefix(rule[1], prefix))
	}

	if len(objectConditions) == 0 {
		return nil, errors.ErrEmptyCondition
	}

	return objectConditions, nil
}
// removeDuplicatePermissions deduplicates permissions, using each permission's
// string form (via util.ArrayToString) as the set key. Order of first
// occurrence is preserved.
func removeDuplicatePermissions(permissions [][]string) [][]string {
	seen := make(map[string]bool)
	unique := make([][]string, 0)
	for _, permission := range permissions {
		key := util.ArrayToString(permission)
		if seen[key] {
			continue
		}
		seen[key] = true
		unique = append(unique, permission)
	}
	return unique
}
// GetImplicitUsersForResource return implicit user based on resource.
// for example:
// p, alice, data1, read
// p, bob, data2, write
// p, data2_admin, data2, read
// p, data2_admin, data2, write
// g, alice, data2_admin
// GetImplicitUsersForResource("data2") will return [[bob data2 write] [alice data2 read] [alice data2 write]]
// GetImplicitUsersForResource("data1") will return [[alice data1 read]]
// Note: only users will be returned, roles (2nd arg in "g") will be excluded.
func (e *Enforcer) GetImplicitUsersForResource(resource string) ([][]string, error) {
	permissions := make([][]string, 0)
	subjectIndex, _ := e.GetFieldIndex("p", "sub")
	objectIndex, _ := e.GetFieldIndex("p", "obj")
	rm := e.GetRoleManager()

	// Collect every role name so subjects that are roles can be told apart
	// from plain users when scanning the policy below.
	isRole := make(map[string]bool)
	for _, role := range e.GetAllRoles() {
		isRole[role] = true
	}

	for _, rule := range e.model["p"]["p"].Policy {
		obj := rule[objectIndex]
		if obj != resource {
			continue
		}

		sub := rule[subjectIndex]

		if !isRole[sub] {
			// Subject is a plain user: the policy rule applies to it directly.
			permissions = append(permissions, rule)
		} else {
			// Subject is a role: expand to the users holding that role,
			// emitting a deep copy per user so the stored policy is untouched.
			// NOTE(review): rm.GetUsers appears to return direct members only;
			// nested role memberships are presumably not expanded — confirm upstream.
			users, err := rm.GetUsers(sub)
			if err != nil {
				return nil, err
			}
			for _, user := range users {
				implicitUserRule := deepCopyPolicy(rule)
				implicitUserRule[subjectIndex] = user

				permissions = append(permissions, implicitUserRule)
			}
		}
	}

	// Role expansion can yield the same permission via several paths; deduplicate.
	res := removeDuplicatePermissions(permissions)
	return res, nil
}
// GetImplicitUsersForResourceByDomain return implicit user based on resource and domain.
// Compared to GetImplicitUsersForResource, domain is supported
func (e *Enforcer) GetImplicitUsersForResourceByDomain(resource string, domain string) ([][]string, error) {
	permissions := make([][]string, 0)
	subjectIndex, _ := e.GetFieldIndex("p", "sub")
	objectIndex, _ := e.GetFieldIndex("p", "obj")
	domIndex, _ := e.GetFieldIndex("p", "dom")
	rm := e.GetRoleManager()

	// Roles scoped to this domain; used to distinguish role subjects from users.
	isRole := make(map[string]bool)
	for _, role := range e.GetAllRolesByDomain(domain) {
		isRole[role] = true
	}

	for _, rule := range e.model["p"]["p"].Policy {
		obj := rule[objectIndex]
		if obj != resource {
			continue
		}

		sub := rule[subjectIndex]

		if !isRole[sub] {
			// Plain user subject: the rule applies directly.
			// NOTE(review): this branch does not filter by rule's domain column —
			// presumably intentional (mirrors GetImplicitUsersForResource); confirm.
			permissions = append(permissions, rule)
		} else {
			// Role subject: only expand rules belonging to the requested domain.
			if domain != rule[domIndex] {
				continue
			}
			users, err := rm.GetUsers(sub, domain)
			if err != nil {
				return nil, err
			}
			for _, user := range users {
				// Copy the rule per user so the stored policy is not mutated.
				implicitUserRule := deepCopyPolicy(rule)
				implicitUserRule[subjectIndex] = user
				permissions = append(permissions, implicitUserRule)
			}
		}
	}
	// Deduplicate permissions reachable through multiple role paths.
	res := removeDuplicatePermissions(permissions)
	return res, nil
}

View File

@ -160,8 +160,8 @@ func (e *SyncedEnforcer) GetImplicitRolesForUser(name string, domain ...string)
// GetPermissionsForUser("alice") can only get: [["alice", "data2", "read"]].
// But GetImplicitPermissionsForUser("alice") will get: [["admin", "data1", "read"], ["alice", "data2", "read"]].
func (e *SyncedEnforcer) GetImplicitPermissionsForUser(user string, domain ...string) ([][]string, error) {
e.m.RLock()
defer e.m.RUnlock()
e.m.Lock()
defer e.m.Unlock()
return e.Enforcer.GetImplicitPermissionsForUser(user, domain...)
}

View File

@ -144,3 +144,24 @@ func (e *Enforcer) DeleteDomains(domains ...string) (bool, error) {
func (e *Enforcer) GetAllDomains() ([]string, error) {
return e.model["g"]["g"].RM.GetAllDomains()
}
// GetAllRolesByDomain would get all roles associated with the domain.
// note: Not applicable to Domains with inheritance relationship (implicit roles)
func (e *Enforcer) GetAllRolesByDomain(domain string) []string {
	rules := e.model["g"]["g"].Policy
	roles := make([]string, 0)
	seen := make(map[string]bool) // remove duplicates

	for _, rule := range rules {
		// The domain is the last element of a grouping rule, the role the one before it.
		if rule[len(rule)-1] != domain {
			continue
		}
		role := rule[len(rule)-2]
		if !seen[role] {
			seen[role] = true
			roles = append(roles, role)
		}
	}

	return roles
}

View File

@ -34,13 +34,13 @@ var (
// validate the variadic parameter size and type as string
func validateVariadicArgs(expectedLen int, args ...interface{}) error {
if len(args) != expectedLen {
return fmt.Errorf("Expected %d arguments, but got %d", expectedLen, len(args))
return fmt.Errorf("expected %d arguments, but got %d", expectedLen, len(args))
}
for _, p := range args {
_, ok := p.(string)
if !ok {
return errors.New("Argument must be a string")
return errors.New("argument must be a string")
}
}
@ -272,15 +272,24 @@ func KeyMatch4Func(args ...interface{}) (interface{}, error) {
return bool(KeyMatch4(name1, name2)), nil
}
// KeyMatch determines whether key1 matches the pattern of key2 and ignores the parameters in key2.
// For example, "/foo/bar?status=1&type=2" matches "/foo/bar"
// KeyMatch5 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *
// For example,
// - "/foo/bar?status=1&type=2" matches "/foo/bar"
// - "/parent/child1" and "/parent/child1" matches "/parent/*"
// - "/parent/child1?status=1" matches "/parent/*"
func KeyMatch5(key1 string, key2 string) bool {
i := strings.Index(key1, "?")
if i == -1 {
return key1 == key2
if i != -1 {
key1 = key1[:i]
}
return key1[:i] == key2
key2 = strings.Replace(key2, "/*", "/.*", -1)
re := regexp.MustCompile(`\{[^/]+\}`)
key2 = re.ReplaceAllString(key2, "$1[^/]+$2")
return RegexMatch(key1, "^"+key2+"$")
}
// KeyMatch5Func is the wrapper for KeyMatch5.

View File

@ -70,6 +70,44 @@ func Array2DEquals(a [][]string, b [][]string) bool {
return true
}
// SortArray2D sorts the two-dimensional string array in place,
// comparing rows lexicographically element by element; when one row is a
// prefix of the other, the shorter row orders first.
//
// Fixes over the previous version: the comparator returned true for equal
// rows, violating sort.Slice's strict "less" contract, and it indexed every
// row up to len(arr[0]), panicking on rows shorter than the first one.
func SortArray2D(arr [][]string) {
	sort.Slice(arr, func(i, j int) bool {
		a, b := arr[i], arr[j]

		// Compare only up to the shorter row to avoid out-of-range access.
		n := len(a)
		if len(b) < n {
			n = len(b)
		}
		for k := 0; k < n; k++ {
			if a[k] != b[k] {
				return a[k] < b[k]
			}
		}

		// Shared prefix: shorter row is "less"; equal rows must return false.
		return len(a) < len(b)
	})
}
// SortedArray2DEquals determines whether two 2-dimensional string arrays
// contain the same rows, ignoring row order. Neither input is modified:
// only copies of the outer slices are sorted.
func SortedArray2DEquals(a [][]string, b [][]string) bool {
	if len(a) != len(b) {
		return false
	}

	sortedA := make([][]string, len(a))
	copy(sortedA, a)
	sortedB := make([][]string, len(b))
	copy(sortedB, b)

	SortArray2D(sortedA)
	SortArray2D(sortedB)

	for i := range sortedA {
		if !ArrayEquals(sortedA[i], sortedB[i]) {
			return false
		}
	}

	return true
}
// ArrayRemoveDuplicates removes any duplicated elements in a string array.
func ArrayRemoveDuplicates(s *[]string) {
found := make(map[string]bool)
@ -323,8 +361,8 @@ func NewSyncLRUCache(capacity int) *SyncLRUCache {
}
func (cache *SyncLRUCache) Get(key interface{}) (value interface{}, ok bool) {
cache.rwm.RLock()
defer cache.rwm.RUnlock()
cache.rwm.Lock()
defer cache.rwm.Unlock()
return cache.LRUCache.Get(key)
}

View File

@ -7,6 +7,10 @@ all: build
test:
go test -race -coverprofile=/dev/null -v ./...
## fuzz: Run fuzz tests
fuzz:
go test -fuzz=Fuzz -run=^Fuzz ./internal/packet -fuzztime 30s
## vet: Analyze code for potential errors
vet:
go vet ./...
@ -58,7 +62,7 @@ docker:
logtopics:
grep -ERho 'log\("([^"]+)' *.go | sed -E -e 's/log\("//' | sort -u
.PHONY: help test vet fmt vendor commit coverage lint client server update logtopics
.PHONY: help test fuzz vet fmt vendor commit coverage lint client server update logtopics
## help: Show all commands
help: Makefile

View File

@ -1,5 +1,13 @@
# GoSRT
Implementation of the SRT protocol in pure Go with minimal dependencies.
<p align="left">
<a href="http://srtalliance.org/">
<img alt="SRT" src="https://github.com/datarhei/misc/blob/main/img/gosrt.png?raw=true" width="600"/>
</a>
</p>
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
![Tests](https://github.com/datarhei/gosrt/actions/workflows/go-tests.yml/badge.svg)
[![codecov](https://codecov.io/gh/datarhei/gosrt/branch/main/graph/badge.svg?token=90YMPZRAFK)](https://codecov.io/gh/datarhei/gosrt)
@ -10,11 +18,14 @@ Implementation of the SRT protocol in pure Go with minimal dependencies.
- [SRT RFC](https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html)
- [SRT Technical Overview](https://github.com/Haivision/srt/files/2489142/SRT_Protocol_TechnicalOverview_DRAFT_2018-10-17.pdf)
## Implementations
This implementation of the SRT protocol has live streaming of video/audio in mind. Because of this, the buffer mode and File Transfer
Congestion Control (FileCC) are not implemented.
| | |
| --- | ----------------------------------------- |
| ✅ | Handshake v4 and v5 |
| ✅ | Message mode |
| ✅ | Caller-Listener Handshake |
| ✅ | Timestamp-Based Packet Delivery (TSBPD) |
@ -29,17 +40,17 @@ Congestion Control (FileCC) are not implemented.
The parts that are implemented are based on what has been published in the SRT RFC.
# Requirements
## Requirements
A Go version of 1.16+ is required.
# Installation
## Installation
```
go get github.com/datarhei/gosrt
```
# Caller example
## Caller example
```
import "github.com/datarhei/gosrt"
@ -67,7 +78,7 @@ conn.Close()
In the `contrib/client` directory you'll find a complete example of a SRT client.
# Listener example
## Listener example
```
import "github.com/datarhei/gosrt"
@ -103,7 +114,7 @@ In the `contrib/server` directory you'll find a complete example of a SRT server
this module provides the `Server` type, which is a light framework for creating your own SRT server. The
example server is based on this type.
## PUBLISH / SUBSCRIBE
### PUBLISH / SUBSCRIBE
The `Accept` function from the `Listener` expects a function that handles the connection requests. It can
return 3 different values: `srt.PUBLISH`, `srt.SUBSCRIBE`, and `srt.REJECT`. `srt.PUBLISH` means that the
@ -111,7 +122,7 @@ server expects the caller to send data, whereas `srt.SUBSCRIBE` means that the s
the caller. This is opinionated towards a streaming server; however, in your implementation of a listener
you are free to handle connection requests to your liking.
# Contributed client
## Contributed client
In the `contrib/client` directory you'll find an example implementation of a SRT client.
@ -130,7 +141,7 @@ The application requires only two options:
Both options accept an address. Valid addresses are: `-` for `stdin`, resp. `stdout`, a `srt://` address, or an `udp://` address.
## SRT URL
### SRT URL
A SRT URL is of the form `srt://[host]:[port]/?[options]` where options are in the form of a `HTTP` query string. These are the
known options (similar to [srt-live-transmit](https://github.com/Haivision/srt/blob/master/docs/apps/srt-live-transmit.md)):
@ -172,7 +183,7 @@ known options (similar to [srt-live-transmit](https://github.com/Haivision/srt/b
| `transtype` | `live` | Transmission type. Must be `live`. |
| `tsbpdmode` | `bool` | Enable timestamp-based packet delivery mode. |
## Usage
### Usage
Reading from a SRT sender and play with `ffplay`:
@ -204,7 +215,7 @@ In the third console connect to that stream and play the video with `ffplay`:
./client -from "srt://127.0.0.1:6001/?mode=caller&streamid=foobar" -to - | ffplay -f mpegts -i -
```
# Contributed server
## Contributed server
In the `contrib/server` directory you'll find an example implementation of a SRT server. This server allows you to publish
a stream that can be read by many clients.
@ -237,7 +248,7 @@ Use `-logtopics` in order to write debug output. The value are a comma separated
Use `-profile` in order to write a CPU profile.
## StreamID
### StreamID
In SRT the StreamID is used to transport somewhat arbitrary information from the caller to the listener. The provided example server uses this
mechanism to decide who is the sender and who is the receiver. The server must know if the connecting client wants to publish a stream or
@ -249,7 +260,7 @@ receive data.
If you implement your own server you are free to interpret the streamID as you wish.
## Usage
### Usage
Running a server listening on port 6001 with defaults:
@ -280,7 +291,7 @@ ffplay -f mpegts -transtype live -i "srt://127.0.0.1:6001?streamid=/live/stream"
You will most likely first see some error messages from `ffplay` because it tries to make sense of the received data until a keyframe arrives. If you
get more errors during playback, you might increase the receive buffer by adding e.g. `-rcvlatency 1000000` to the command line.
## Encryption
### Encryption
The stream can be encrypted with a passphrase. First start the server with a passphrase. If you are using `srt-live-transmit`, the passphrase has to be at least 10 characters long otherwise it will not be accepted.
@ -303,7 +314,7 @@ ffplay -f mpegts -transtype live -i "srt://127.0.0.1:6001?streamid=/live/stream&
You will most likely first see some error messages from `ffplay` because it tries to make sense of the received data until a keyframe arrives. If you
get more errors during playback, you might increase the receive buffer by adding e.g. `-rcvlatency 1000000` to the command line.
# Logging
## Logging
This SRT module has a built-in logging facility for debugging purposes. Check the `Logger` interface and the `NewLogger(topics []string)` function. Because logging everything would be too much output if you only want to debug something specific, you can limit the logging to specific areas, like everything regarding a connection or only the handshake. That's why there are various topics.
@ -380,7 +391,7 @@ packet:send:dump
You can run `make logtopics` in order to extract the list of topics.
# Docker
## Docker
The docker image you can build with `docker build -t srt .` provides the example SRT client and server as mentioned in the paragraph above.
E.g. run the server with `docker run -it --rm -p 6001:6001/udp srt srt-server -addr :6001`.

View File

@ -625,9 +625,9 @@ func (c *Config) MarshalQuery() string {
return q.Encode()
}
// Validate validates a configuration or returns an error if a field
// Validate validates a configuration, returns an error if a field
// has an invalid value.
func (c Config) Validate() error {
func (c *Config) Validate() error {
if c.TransmissionType != "live" {
return fmt.Errorf("config: TransmissionType must be 'live'")
}

View File

@ -54,6 +54,9 @@ type Conn interface {
// Stats returns accumulated and instantaneous statistics of the connection.
Stats(s *Statistics)
// Version returns the connection version, either 4 or 5. With version 4, the streamid is not available
Version() uint32
}
type connStats struct {
@ -80,6 +83,9 @@ type connStats struct {
var _ net.Conn = &srtConn{}
type srtConn struct {
version uint32
isCaller bool // Only relevant if version == 4
localAddr net.Addr
remoteAddr net.Addr
@ -155,9 +161,15 @@ type srtConn struct {
expectedRcvPacketSequenceNumber circular.Number
expectedReadPacketSequenceNumber circular.Number
}
// HSv4
stopHSRequests context.CancelFunc
stopKMRequests context.CancelFunc
}
type srtConnConfig struct {
version uint32
isCaller bool
localAddr net.Addr
remoteAddr net.Addr
config Config
@ -177,6 +189,8 @@ type srtConnConfig struct {
func newSRTConn(config srtConnConfig) *srtConn {
c := &srtConn{
version: config.version,
isCaller: config.isCaller,
localAddr: config.localAddr,
remoteAddr: config.remoteAddr,
config: config.config,
@ -217,7 +231,14 @@ func newSRTConn(config srtConnConfig) *srtConn {
c.networkQueue = make(chan packet.Packet, 1024)
c.writeQueue = make(chan packet.Packet, 1024)
c.writeData = make([]byte, int(c.config.PayloadSize))
if c.version == 4 {
// libsrt-1.2.3 receiver doesn't like it when the payload is larger than 7*188 bytes.
// Here we just take a multiple of a mpegts chunk size.
c.writeData = make([]byte, int(c.config.PayloadSize/188*188))
} else {
// For v5 we use the max. payload size: https://github.com/Haivision/srt/issues/876
c.writeData = make([]byte, int(c.config.PayloadSize))
}
c.readQueue = make(chan packet.Packet, 1024)
@ -281,6 +302,18 @@ func newSRTConn(config srtConnConfig) *srtConn {
c.statistics.headerSize += 40 // 40 bytes IPv6 header
}
if c.version == 4 && c.isCaller {
var hsrequestsCtx context.Context
hsrequestsCtx, c.stopHSRequests = context.WithCancel(context.Background())
go c.sendHSRequests(hsrequestsCtx)
if c.crypto != nil {
var kmrequestsCtx context.Context
kmrequestsCtx, c.stopKMRequests = context.WithCancel(context.Background())
go c.sendKMRequests(kmrequestsCtx)
}
}
return c
}
@ -306,6 +339,10 @@ func (c *srtConn) StreamId() string {
return c.config.StreamId
}
func (c *srtConn) Version() uint32 {
return c.version
}
// ticker invokes the congestion control in regular intervals with
// the current connection time.
func (c *srtConn) ticker(ctx context.Context) {
@ -475,7 +512,7 @@ func (c *srtConn) pop(p packet.Packet) {
c.kmRefreshCountdown--
if c.kmPreAnnounceCountdown == 0 && !c.kmConfirmed {
c.sendKMRequest()
c.sendKMRequest(c.keyBaseEncryption.Opposite())
// Resend the request until we get a response
c.kmPreAnnounceCountdown = c.config.KMPreAnnounce/10 + 1
@ -578,6 +615,17 @@ func (c *srtConn) handlePacket(p packet.Packet) {
} else if header.ControlType == packet.CTRLTYPE_ACKACK {
c.handleACKACK(p)
} else if header.ControlType == packet.CTRLTYPE_USER {
c.log("connection:recv:ctrl:user", func() string {
return fmt.Sprintf("got CTRLTYPE_USER packet, subType: %s", header.SubType)
})
// HSv4 Extension
if header.SubType == packet.EXTTYPE_HSREQ {
c.handleHSRequest(p)
} else if header.SubType == packet.EXTTYPE_HSRSP {
c.handleHSResponse(p)
}
// 3.2.2. Key Material
if header.SubType == packet.EXTTYPE_KMREQ {
c.handleKMRequest(p)
@ -775,35 +823,221 @@ func (c *srtConn) recalculateRTT(rtt time.Duration) {
})
}
// handleKMRequest checks if the key material is valid and responds with a KM response.
func (c *srtConn) handleKMRequest(p packet.Packet) {
c.log("control:recv:KM:dump", func() string { return p.Dump() })
// handleHSRequest handles the HSv4 handshake extension request and sends the response
func (c *srtConn) handleHSRequest(p packet.Packet) {
c.log("control:recv:HSReq:dump", func() string { return p.Dump() })
c.statistics.pktRecvKM++
c.cryptoLock.Lock()
if c.crypto == nil {
c.log("control:recv:KM:error", func() string { return "connection is not encrypted" })
c.cryptoLock.Unlock()
return
}
cif := &packet.CIFKM{}
cif := &packet.CIFHandshakeExtension{}
if err := p.UnmarshalCIF(cif); err != nil {
c.statistics.pktRecvInvalid++
c.log("control:recv:KM:error", func() string { return fmt.Sprintf("invalid KM: %s", err) })
c.log("control:recv:HSReq:error", func() string { return fmt.Sprintf("invalid HSReq: %s", err) })
return
}
c.log("control:recv:HSReq:cif", func() string { return cif.String() })
// Check for version
if cif.SRTVersion < 0x010200 || cif.SRTVersion >= 0x010300 {
c.log("control:recv:HSReq:error", func() string { return fmt.Sprintf("unsupported version: %#08x", cif.SRTVersion) })
c.close()
return
}
// Check the required SRT flags
if !cif.SRTFlags.TSBPDSND {
c.log("control:recv:HSRes:error", func() string { return "TSBPDSND flag must be set" })
c.close()
return
}
if !cif.SRTFlags.TLPKTDROP {
c.log("control:recv:HSRes:error", func() string { return "TLPKTDROP flag must be set" })
c.close()
return
}
if !cif.SRTFlags.CRYPT {
c.log("control:recv:HSRes:error", func() string { return "CRYPT flag must be set" })
c.close()
return
}
if !cif.SRTFlags.REXMITFLG {
c.log("control:recv:HSRes:error", func() string { return "REXMITFLG flag must be set" })
c.close()
return
}
// we as receiver don't need this
cif.SRTFlags.TSBPDSND = false
// we as receiver are supporting these
cif.SRTFlags.TSBPDRCV = true
cif.SRTFlags.PERIODICNAK = true
// These flag was introduced in HSv5 and should not be set in HSv4
if cif.SRTFlags.STREAM {
c.log("control:recv:HSReq:error", func() string { return "STREAM flag is set" })
c.close()
return
}
if cif.SRTFlags.PACKET_FILTER {
c.log("control:recv:HSReq:error", func() string { return "PACKET_FILTER flag is set" })
c.close()
return
}
recvTsbpdDelay := uint16(c.config.ReceiverLatency.Milliseconds())
if cif.SendTSBPDDelay > recvTsbpdDelay {
recvTsbpdDelay = cif.SendTSBPDDelay
}
c.tsbpdDelay = uint64(recvTsbpdDelay) * 1000
cif.RecvTSBPDDelay = 0
cif.SendTSBPDDelay = recvTsbpdDelay
p.MarshalCIF(cif)
// Send HS Response
p.Header().SubType = packet.EXTTYPE_HSRSP
c.pop(p)
}
// handleHSResponse handles the HSv4 handshake extension response
func (c *srtConn) handleHSResponse(p packet.Packet) {
c.log("control:recv:HSRes:dump", func() string { return p.Dump() })
cif := &packet.CIFHandshakeExtension{}
if err := p.UnmarshalCIF(cif); err != nil {
c.statistics.pktRecvInvalid++
c.log("control:recv:HSRes:error", func() string { return fmt.Sprintf("invalid HSRes: %s", err) })
return
}
c.log("control:recv:HSRes:cif", func() string { return cif.String() })
if c.version == 4 {
// Check for version
if cif.SRTVersion < 0x010200 || cif.SRTVersion >= 0x010300 {
c.log("control:recv:HSRes:error", func() string { return fmt.Sprintf("unsupported version: %#08x", cif.SRTVersion) })
c.close()
return
}
// TSBPDSND is not relevant from the receiver
// PERIODICNAK is the sender's decision, we don't care, but will handle them
// Check the required SRT flags
if !cif.SRTFlags.TSBPDRCV {
c.log("control:recv:HSRes:error", func() string { return "TSBPDRCV flag must be set" })
c.close()
return
}
if !cif.SRTFlags.TLPKTDROP {
c.log("control:recv:HSRes:error", func() string { return "TLPKTDROP flag must be set" })
c.close()
return
}
if !cif.SRTFlags.CRYPT {
c.log("control:recv:HSRes:error", func() string { return "CRYPT flag must be set" })
c.close()
return
}
if !cif.SRTFlags.REXMITFLG {
c.log("control:recv:HSRes:error", func() string { return "REXMITFLG flag must be set" })
c.close()
return
}
// These flag was introduced in HSv5 and should not be set in HSv4
if cif.SRTFlags.STREAM {
c.log("control:recv:HSReq:error", func() string { return "STREAM flag is set" })
c.close()
return
}
if cif.SRTFlags.PACKET_FILTER {
c.log("control:recv:HSReq:error", func() string { return "PACKET_FILTER flag is set" })
c.close()
return
}
sendTsbpdDelay := uint16(c.config.PeerLatency.Milliseconds())
if cif.SendTSBPDDelay > sendTsbpdDelay {
sendTsbpdDelay = cif.SendTSBPDDelay
}
c.dropThreshold = uint64(float64(sendTsbpdDelay)*1.25) + uint64(c.config.SendDropDelay.Microseconds())
if c.dropThreshold < uint64(time.Second.Microseconds()) {
c.dropThreshold = uint64(time.Second.Microseconds())
}
c.dropThreshold += 20_000
c.snd.SetDropThreshold(c.dropThreshold)
c.stopHSRequests()
}
}
// handleKMRequest checks if the key material is valid and responds with a KM response.
func (c *srtConn) handleKMRequest(p packet.Packet) {
c.log("control:recv:KMReq:dump", func() string { return p.Dump() })
c.statistics.pktRecvKM++
cif := &packet.CIFKeyMaterialExtension{}
if err := p.UnmarshalCIF(cif); err != nil {
c.statistics.pktRecvInvalid++
c.log("control:recv:KMReq:error", func() string { return fmt.Sprintf("invalid KMReq: %s", err) })
return
}
c.log("control:recv:KMReq:cif", func() string { return cif.String() })
c.cryptoLock.Lock()
if c.version == 4 && c.crypto == nil {
cr, err := crypto.New(int(cif.KLen))
if err != nil {
c.log("control:recv:KMReq:error", func() string { return fmt.Sprintf("crypto: %s", err) })
c.cryptoLock.Unlock()
c.close()
return
}
c.keyBaseEncryption = cif.KeyBasedEncryption.Opposite()
c.crypto = cr
}
if c.crypto == nil {
c.log("control:recv:KMReq:error", func() string { return "connection is not encrypted" })
c.cryptoLock.Unlock()
return
}
c.log("control:recv:KM:cif", func() string { return cif.String() })
if cif.KeyBasedEncryption == c.keyBaseEncryption {
c.statistics.pktRecvInvalid++
c.log("control:recv:KM:error", func() string {
return "invalid KM. wants to reset the key that is already in use"
c.log("control:recv:KMReq:error", func() string {
return "invalid KM request. wants to reset the key that is already in use"
})
c.cryptoLock.Unlock()
return
@ -811,7 +1045,7 @@ func (c *srtConn) handleKMRequest(p packet.Packet) {
if err := c.crypto.UnmarshalKM(cif, c.config.Passphrase); err != nil {
c.statistics.pktRecvInvalid++
c.log("control:recv:KM:error", func() string { return fmt.Sprintf("invalid KM: %s", err) })
c.log("control:recv:KMReq:error", func() string { return fmt.Sprintf("invalid KMReq: %s", err) })
c.cryptoLock.Unlock()
return
}
@ -831,20 +1065,44 @@ func (c *srtConn) handleKMRequest(p packet.Packet) {
// handleKMResponse confirms the change of encryption keys.
func (c *srtConn) handleKMResponse(p packet.Packet) {
c.log("control:recv:KM:dump", func() string { return p.Dump() })
c.log("control:recv:KMRes:dump", func() string { return p.Dump() })
c.statistics.pktRecvKM++
cif := &packet.CIFKeyMaterialExtension{}
if err := p.UnmarshalCIF(cif); err != nil {
c.statistics.pktRecvInvalid++
c.log("control:recv:KMRes:error", func() string { return fmt.Sprintf("invalid KMRes: %s", err) })
return
}
c.cryptoLock.Lock()
defer c.cryptoLock.Unlock()
if c.crypto == nil {
c.log("control:recv:KM:error", func() string { return "connection is not encrypted" })
c.log("control:recv:KMRes:error", func() string { return "connection is not encrypted" })
return
}
if c.version == 4 {
c.stopKMRequests()
if cif.Error != 0 {
if cif.Error == packet.KM_NOSECRET {
c.log("control:recv:KMRes:error", func() string { return "peer didn't enabled encryption" })
} else if cif.Error == packet.KM_BADSECRET {
c.log("control:recv:KMRes:error", func() string { return "peer has a different passphrase" })
}
c.close()
return
}
}
c.log("control:recv:KMRes:cif", func() string { return cif.String() })
if c.kmPreAnnounceCountdown >= c.config.KMPreAnnounce {
c.log("control:recv:KM:error", func() string { return "not in pre-announce period" })
c.log("control:recv:KMRes:error", func() string { return "not in pre-announce period, ignored" })
// Ignore the response, we're not in the pre-announce period
return
}
@ -964,16 +1222,73 @@ func (c *srtConn) sendACKACK(ackSequence uint32) {
c.pop(p)
}
// sendHSRequests resends the HSv4 handshake extension request every 500ms
// until ctx is cancelled (via stopHSRequests, triggered when a handshake
// response is received).
//
// Fix: the previous version had a bare select around the ticker, so it sent
// at most one request and then returned instead of looping until cancelled.
func (c *srtConn) sendHSRequests(ctx context.Context) {
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			c.sendHSRequest()
		}
	}
}
// sendHSRequest sends a single HSv4 handshake extension request
// (CTRLTYPE_USER / EXTTYPE_HSREQ) announcing our SRT version, capability
// flags and receiver latency to the peer.
func (c *srtConn) sendHSRequest() {
	cif := &packet.CIFHandshakeExtension{
		SRTVersion: 0x00010203,
		SRTFlags: packet.CIFHandshakeExtensionFlags{
			TSBPDSND:      true,  // we send in TSBPD mode
			TSBPDRCV:      false, // not relevant for us as sender
			CRYPT:         true,  // must always be set
			TLPKTDROP:     true,  // must be set in live mode
			PERIODICNAK:   false, // not relevant for us as sender
			REXMITFLG:     true,  // must always be set
			STREAM:        false, // has been introduced in HSv5
			PACKET_FILTER: false, // has been introduced in HSv5
		},
		RecvTSBPDDelay: 0,
		SendTSBPDDelay: uint16(c.config.ReceiverLatency.Milliseconds()),
	}

	p := packet.NewPacket(c.remoteAddr, nil)

	p.Header().IsControlPacket = true

	p.Header().ControlType = packet.CTRLTYPE_USER
	p.Header().SubType = packet.EXTTYPE_HSREQ
	p.Header().Timestamp = c.getTimestampForPacket()

	p.MarshalCIF(cif)

	c.log("control:send:HSReq:dump", func() string { return p.Dump() })
	c.log("control:send:HSReq:cif", func() string { return cif.String() })

	c.pop(p)
}
// sendKMRequests resends the HSv4 key material request every 500ms until
// ctx is cancelled (via stopKMRequests, triggered when a KM response is
// received from the peer).
//
// Fix: the previous version had a bare select around the ticker, so it sent
// at most one request and then returned instead of looping until cancelled.
func (c *srtConn) sendKMRequests(ctx context.Context) {
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			c.sendKMRequest(c.keyBaseEncryption)
		}
	}
}
// sendKMRequest sends a KM request to the peer.
func (c *srtConn) sendKMRequest() {
func (c *srtConn) sendKMRequest(key packet.PacketEncryption) {
if c.crypto == nil {
c.log("control:send:KM:error", func() string { return "connection is not encrypted" })
c.log("control:send:KMReq:error", func() string { return "connection is not encrypted" })
return
}
cif := &packet.CIFKM{}
cif := &packet.CIFKeyMaterialExtension{}
c.crypto.MarshalKM(cif, c.config.Passphrase, c.keyBaseEncryption.Opposite())
c.crypto.MarshalKM(cif, c.config.Passphrase, key)
p := packet.NewPacket(c.remoteAddr, nil)
@ -985,8 +1300,8 @@ func (c *srtConn) sendKMRequest() {
p.MarshalCIF(cif)
c.log("control:send:KM:dump", func() string { return p.Dump() })
c.log("control:send:KM:cif", func() string { return cif.String() })
c.log("control:send:KMReq:dump", func() string { return p.Dump() })
c.log("control:send:KMReq:cif", func() string { return cif.String() })
c.statistics.pktSentKM++

View File

@ -24,6 +24,8 @@ var ErrClientClosed = errors.New("srt: client closed")
// dialer implements the Conn interface
type dialer struct {
version uint32
pc *net.UDPConn
localAddr net.Addr
@ -331,28 +333,25 @@ func (dl *dialer) handleHandshake(p packet.Packet) {
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = uint32(time.Since(dl.start).Microseconds())
p.Header().DestinationSocketId = cif.SRTSocketId
p.Header().DestinationSocketId = 0 // must be 0 for handshake
if cif.HandshakeType == packet.HSTYPE_INDUCTION {
// Verify version
if cif.Version != 5 {
if cif.Version < 4 || cif.Version > 5 {
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't support handshake v5"),
err: fmt.Errorf("peer responded with unsupported handshake version (%d)", cif.Version),
}
return
}
// Verify magic number
if cif.ExtensionField != 0x4A17 {
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer sent the wrong magic number"),
}
return
}
cif.IsRequest = true
cif.HandshakeType = packet.HSTYPE_CONCLUSION
cif.InitialPacketSequenceNumber = dl.initialPacketSequenceNumber
cif.MaxTransmissionUnitSize = dl.config.MSS // MTU size
cif.MaxFlowWindowSize = dl.config.FC
cif.SRTSocketId = dl.socketId
cif.PeerIP.FromNetAddr(dl.localAddr)
// Setup crypto context
if len(dl.config.Passphrase) != 0 {
@ -382,42 +381,62 @@ func (dl *dialer) handleHandshake(p packet.Packet) {
dl.crypto = cr
}
cif.IsRequest = true
cif.HandshakeType = packet.HSTYPE_CONCLUSION
cif.InitialPacketSequenceNumber = dl.initialPacketSequenceNumber
cif.MaxTransmissionUnitSize = dl.config.MSS // MTU size
cif.MaxFlowWindowSize = dl.config.FC
cif.SRTSocketId = dl.socketId
cif.PeerIP.FromNetAddr(dl.localAddr)
// Verify version
if cif.Version == 5 {
dl.version = 5
cif.HasHS = true
cif.SRTVersion = SRT_VERSION
cif.SRTFlags.TSBPDSND = true
cif.SRTFlags.TSBPDRCV = true
cif.SRTFlags.CRYPT = true // must always set to true
cif.SRTFlags.TLPKTDROP = true
cif.SRTFlags.PERIODICNAK = true
cif.SRTFlags.REXMITFLG = true
cif.SRTFlags.STREAM = false
cif.SRTFlags.PACKET_FILTER = false
cif.RecvTSBPDDelay = uint16(dl.config.ReceiverLatency.Milliseconds())
cif.SendTSBPDDelay = uint16(dl.config.PeerLatency.Milliseconds())
cif.HasSID = true
cif.StreamId = dl.config.StreamId
if dl.crypto != nil {
cif.HasKM = true
cif.SRTKM = &packet.CIFKM{}
if err := dl.crypto.MarshalKM(cif.SRTKM, dl.config.Passphrase, packet.EvenKeyEncrypted); err != nil {
// Verify magic number
if cif.ExtensionField != 0x4A17 {
dl.connChan <- connResponse{
conn: nil,
err: err,
err: fmt.Errorf("peer sent the wrong magic number"),
}
return
}
cif.HasHS = true
cif.SRTHS = &packet.CIFHandshakeExtension{
SRTVersion: SRT_VERSION,
SRTFlags: packet.CIFHandshakeExtensionFlags{
TSBPDSND: true,
TSBPDRCV: true,
CRYPT: true, // must always set to true
TLPKTDROP: true,
PERIODICNAK: true,
REXMITFLG: true,
STREAM: false,
PACKET_FILTER: false,
},
RecvTSBPDDelay: uint16(dl.config.ReceiverLatency.Milliseconds()),
SendTSBPDDelay: uint16(dl.config.PeerLatency.Milliseconds()),
}
cif.HasSID = true
cif.StreamId = dl.config.StreamId
if dl.crypto != nil {
cif.HasKM = true
cif.SRTKM = &packet.CIFKeyMaterialExtension{}
if err := dl.crypto.MarshalKM(cif.SRTKM, dl.config.Passphrase, packet.EvenKeyEncrypted); err != nil {
dl.connChan <- connResponse{
conn: nil,
err: err,
}
return
}
}
} else {
dl.version = 4
cif.EncryptionField = 0
cif.ExtensionField = 2
cif.HasHS = false
cif.HasKM = false
cif.HasSID = false
}
p.MarshalCIF(cif)
@ -427,64 +446,63 @@ func (dl *dialer) handleHandshake(p packet.Packet) {
dl.send(p)
} else if cif.HandshakeType == packet.HSTYPE_CONCLUSION {
// We only support HSv5
if cif.Version != 5 {
dl.sendShutdown(cif.SRTSocketId)
if cif.Version < 4 || cif.Version > 5 {
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't support handshake v5"),
err: fmt.Errorf("peer responded with unsupported handshake version (%d)", cif.Version),
}
return
}
// Check if the peer version is sufficient
if cif.SRTVersion < dl.config.MinVersion {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer SRT version is not sufficient"),
}
return
}
// Check the required SRT flags
if !cif.SRTFlags.TSBPDSND || !cif.SRTFlags.TSBPDRCV || !cif.SRTFlags.TLPKTDROP || !cif.SRTFlags.PERIODICNAK || !cif.SRTFlags.REXMITFLG {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't agree on SRT flags"),
}
return
}
// We only support live streaming
if cif.SRTFlags.STREAM {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't support live streaming"),
}
return
}
// Select the largest TSBPD delay advertised by the listener, but at least 120ms
recvTsbpdDelay := uint16(dl.config.ReceiverLatency.Milliseconds())
sendTsbpdDelay := uint16(dl.config.PeerLatency.Milliseconds())
if cif.SendTSBPDDelay > recvTsbpdDelay {
recvTsbpdDelay = cif.SendTSBPDDelay
}
if cif.Version == 5 {
// Check if the peer version is sufficient
if cif.SRTHS.SRTVersion < dl.config.MinVersion {
dl.sendShutdown(cif.SRTSocketId)
if cif.RecvTSBPDDelay > sendTsbpdDelay {
sendTsbpdDelay = cif.RecvTSBPDDelay
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer SRT version is not sufficient"),
}
return
}
// Check the required SRT flags
if !cif.SRTHS.SRTFlags.TSBPDSND || !cif.SRTHS.SRTFlags.TSBPDRCV || !cif.SRTHS.SRTFlags.TLPKTDROP || !cif.SRTHS.SRTFlags.PERIODICNAK || !cif.SRTHS.SRTFlags.REXMITFLG {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't agree on SRT flags"),
}
return
}
// We only support live streaming
if cif.SRTHS.SRTFlags.STREAM {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't support live streaming"),
}
return
}
// Select the largest TSBPD delay advertised by the listener, but at least 120ms
if cif.SRTHS.SendTSBPDDelay > recvTsbpdDelay {
recvTsbpdDelay = cif.SRTHS.SendTSBPDDelay
}
if cif.SRTHS.RecvTSBPDDelay > sendTsbpdDelay {
sendTsbpdDelay = cif.SRTHS.RecvTSBPDDelay
}
}
// If the peer has a smaller MTU size, adjust to it
@ -506,6 +524,8 @@ func (dl *dialer) handleHandshake(p packet.Packet) {
// Create a new connection
conn := newSRTConn(srtConnConfig{
version: cif.Version,
isCaller: true,
localAddr: dl.localAddr,
remoteAddr: dl.remoteAddr,
config: dl.config,
@ -621,6 +641,10 @@ func (dl *dialer) StreamId() string {
return dl.conn.StreamId()
}
// Version returns the handshake version of the established connection,
// delegating to the underlying SRT connection.
func (dl *dialer) Version() uint32 {
	return dl.conn.Version()
}
func (dl *dialer) isShutdown() bool {
dl.shutdownLock.RLock()
defer dl.shutdownLock.RUnlock()

View File

@ -117,6 +117,10 @@ func (s *liveSend) Push(p packet.Packet) {
// give to the packet a sequence number
p.Header().PacketSequenceNumber = s.nextSequenceNumber
p.Header().PacketPositionFlag = packet.SinglePacket
p.Header().OrderFlag = false
p.Header().MessageNumber = 1
s.nextSequenceNumber = s.nextSequenceNumber.Inc()
pktLen := p.Len()

View File

@ -23,10 +23,10 @@ type Crypto interface {
// UnmarshalMK unwraps the key with the passphrase in a Key Material Extension Message. If the passphrase
// is wrong an error is returned.
UnmarshalKM(km *packet.CIFKM, passphrase string) error
UnmarshalKM(km *packet.CIFKeyMaterialExtension, passphrase string) error
// MarshalKM wraps the key with the passphrase and the odd/even SEK for a Key Material Extension Message.
MarshalKM(km *packet.CIFKM, passphrase string, key packet.PacketEncryption) error
MarshalKM(km *packet.CIFKeyMaterialExtension, passphrase string, key packet.PacketEncryption) error
// EncryptOrDecryptPayload encrypts or decrypts the data of a packet with an even or odd SEK and
// the sequence number.
@ -64,15 +64,17 @@ func New(keyLength int) (Crypto, error) {
return nil, fmt.Errorf("crypto: can't generate salt: %w", err)
}
c.evenSEK = make([]byte, c.keyLength)
if err := c.GenerateSEK(packet.EvenKeyEncrypted); err != nil {
sek, err := c.generateSEK(c.keyLength)
if err != nil {
return nil, err
}
c.evenSEK = sek
c.oddSEK = make([]byte, c.keyLength)
if err := c.GenerateSEK(packet.OddKeyEncrypted); err != nil {
sek, err = c.generateSEK(c.keyLength)
if err != nil {
return nil, err
}
c.oddSEK = sek
return c, nil
}
@ -82,26 +84,38 @@ func (c *crypto) GenerateSEK(key packet.PacketEncryption) error {
return fmt.Errorf("crypto: unknown key type")
}
sek, err := c.generateSEK(c.keyLength)
if err != nil {
return err
}
if key == packet.EvenKeyEncrypted {
if err := c.prng(c.evenSEK); err != nil {
return fmt.Errorf("crypto: can't generate even key: %w", err)
}
c.evenSEK = sek
} else if key == packet.OddKeyEncrypted {
if err := c.prng(c.oddSEK); err != nil {
return fmt.Errorf("crypto: can't generate odd key: %w", err)
}
c.oddSEK = sek
}
return nil
}
func (c *crypto) generateSEK(keyLength int) ([]byte, error) {
sek := make([]byte, keyLength)
err := c.prng(sek)
if err != nil {
return nil, fmt.Errorf("crypto: can't generate SEK: %w", err)
}
return sek, nil
}
// ErrInvalidKey is returned when the packet encryption is invalid
var ErrInvalidKey = errors.New("crypto: invalid key for encryption. Must be even, odd, or both")
// ErrInvalidWrap is returned when the packet encryption indicates a different length of the wrapped key
var ErrInvalidWrap = errors.New("crypto: the unwrapped key has the wrong length")
func (c *crypto) UnmarshalKM(km *packet.CIFKM, passphrase string) error {
func (c *crypto) UnmarshalKM(km *packet.CIFKeyMaterialExtension, passphrase string) error {
if km.KeyBasedEncryption == packet.UnencryptedPacket || !km.KeyBasedEncryption.IsValid() {
return ErrInvalidKey
}
@ -110,7 +124,7 @@ func (c *crypto) UnmarshalKM(km *packet.CIFKM, passphrase string) error {
copy(c.salt, km.Salt)
}
kek := c.calculateKEK(passphrase)
kek := c.calculateKEK(passphrase, c.salt, c.keyLength)
unwrap, err := keywrap.Unwrap(kek, km.Wrap)
if err != nil {
@ -138,7 +152,7 @@ func (c *crypto) UnmarshalKM(km *packet.CIFKM, passphrase string) error {
return nil
}
func (c *crypto) MarshalKM(km *packet.CIFKM, passphrase string, key packet.PacketEncryption) error {
func (c *crypto) MarshalKM(km *packet.CIFKeyMaterialExtension, passphrase string, key packet.PacketEncryption) error {
if key == packet.UnencryptedPacket || !key.IsValid() {
return ErrInvalidKey
}
@ -176,7 +190,7 @@ func (c *crypto) MarshalKM(km *packet.CIFKM, passphrase string, key packet.Packe
copy(w[c.keyLength:], c.oddSEK)
}
kek := c.calculateKEK(passphrase)
kek := c.calculateKEK(passphrase, c.salt, c.keyLength)
wrap, err := keywrap.Wrap(kek, w)
if err != nil {
@ -240,9 +254,9 @@ func (c *crypto) EncryptOrDecryptPayload(data []byte, key packet.PacketEncryptio
}
// calculateKEK calculates a KEK based on the passphrase.
func (c *crypto) calculateKEK(passphrase string) []byte {
func (c *crypto) calculateKEK(passphrase string, salt []byte, keyLength int) []byte {
// 6.1.4. Key Encrypting Key (KEK)
return pbkdf2.Key([]byte(passphrase), c.salt[8:], 2048, c.keyLength, sha1.New)
return pbkdf2.Key([]byte(passphrase), salt[8:], 2048, keyLength, sha1.New)
}
// prng generates a random sequence of byte into the given slice p.

File diff suppressed because it is too large Load Diff

View File

@ -46,6 +46,12 @@ type ConnRequest interface {
// is a copy and can be used at will.
RemoteAddr() net.Addr
// Version returns the handshake version of the incoming request. Currently
// known versions are 4 and 5. With version 4 the StreamId will always be
// empty and IsEncrypted will always return false. An incoming version 4
// connection will always be publishing.
Version() uint32
// StreamId returns the streamid of the requesting connection. Use this
// to decide what to do with the connection.
StreamId() string
@ -77,6 +83,10 @@ func (req *connRequest) RemoteAddr() net.Addr {
return addr
}
// Version returns the handshake version of the incoming connection request.
func (req *connRequest) Version() uint32 {
	return req.handshake.Version
}
// StreamId returns the streamid announced by the requesting connection.
func (req *connRequest) StreamId() string {
	return req.handshake.StreamId
}
@ -86,12 +96,14 @@ func (req *connRequest) IsEncrypted() bool {
}
func (req *connRequest) SetPassphrase(passphrase string) error {
if req.crypto == nil {
return fmt.Errorf("listen: request without encryption")
}
if req.handshake.Version == 5 {
if req.crypto == nil {
return fmt.Errorf("listen: request without encryption")
}
if err := req.crypto.UnmarshalKM(req.handshake.SRTKM, passphrase); err != nil {
return err
if err := req.crypto.UnmarshalKM(req.handshake.SRTKM, passphrase); err != nil {
return err
}
}
req.passphrase = passphrase
@ -312,23 +324,27 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
// Create a new socket ID
socketId := uint32(time.Since(ln.start).Microseconds())
// Select the largest TSBPD delay advertised by the listener, but at least 120ms
// Select the largest TSBPD delay advertised by the caller, but at least 120ms
recvTsbpdDelay := uint16(ln.config.ReceiverLatency.Milliseconds())
sendTsbpdDelay := uint16(ln.config.PeerLatency.Milliseconds())
if request.handshake.SendTSBPDDelay > recvTsbpdDelay {
recvTsbpdDelay = request.handshake.SendTSBPDDelay
if request.handshake.Version == 5 {
if request.handshake.SRTHS.SendTSBPDDelay > recvTsbpdDelay {
recvTsbpdDelay = request.handshake.SRTHS.SendTSBPDDelay
}
if request.handshake.SRTHS.RecvTSBPDDelay > sendTsbpdDelay {
sendTsbpdDelay = request.handshake.SRTHS.RecvTSBPDDelay
}
ln.config.StreamId = request.handshake.StreamId
}
if request.handshake.RecvTSBPDDelay > sendTsbpdDelay {
sendTsbpdDelay = request.handshake.RecvTSBPDDelay
}
ln.config.StreamId = request.handshake.StreamId
ln.config.Passphrase = request.passphrase
// Create a new connection
conn := newSRTConn(srtConnConfig{
version: request.handshake.Version,
localAddr: ln.addr,
remoteAddr: request.addr,
config: ln.config,
@ -351,24 +367,26 @@ func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
request.handshake.SRTSocketId = socketId
request.handshake.SynCookie = 0
// 3.2.1.1.1. Handshake Extension Message Flags
request.handshake.SRTVersion = 0x00010402
request.handshake.SRTFlags.TSBPDSND = true
request.handshake.SRTFlags.TSBPDRCV = true
request.handshake.SRTFlags.CRYPT = true
request.handshake.SRTFlags.TLPKTDROP = true
request.handshake.SRTFlags.PERIODICNAK = true
request.handshake.SRTFlags.REXMITFLG = true
request.handshake.SRTFlags.STREAM = false
request.handshake.SRTFlags.PACKET_FILTER = false
request.handshake.RecvTSBPDDelay = recvTsbpdDelay
request.handshake.SendTSBPDDelay = sendTsbpdDelay
if request.handshake.Version == 5 {
// 3.2.1.1.1. Handshake Extension Message Flags
request.handshake.SRTHS.SRTVersion = SRT_VERSION
request.handshake.SRTHS.SRTFlags.TSBPDSND = true
request.handshake.SRTHS.SRTFlags.TSBPDRCV = true
request.handshake.SRTHS.SRTFlags.CRYPT = true
request.handshake.SRTHS.SRTFlags.TLPKTDROP = true
request.handshake.SRTHS.SRTFlags.PERIODICNAK = true
request.handshake.SRTHS.SRTFlags.REXMITFLG = true
request.handshake.SRTHS.SRTFlags.STREAM = false
request.handshake.SRTHS.SRTFlags.PACKET_FILTER = false
request.handshake.SRTHS.RecvTSBPDDelay = recvTsbpdDelay
request.handshake.SRTHS.SendTSBPDDelay = sendTsbpdDelay
}
ln.accept(request)
// Add the connection to the list of known connections
ln.lock.Lock()
ln.conns[conn.socketId] = conn
ln.conns[socketId] = conn
ln.lock.Unlock()
return conn, mode, nil
@ -574,7 +592,7 @@ func (ln *listener) handleHandshake(p packet.Packet) {
//cif.initialPacketSequenceNumber = newCircular(0, MAX_SEQUENCENUMBER)
//cif.maxTransmissionUnitSize = 0
//cif.maxFlowWindowSize = 0
cif.SRTSocketId = 0
//cif.SRTSocketId = 0
cif.SynCookie = ln.syncookie.Get(p.Header().Addr.String())
p.MarshalCIF(cif)
@ -596,56 +614,6 @@ func (ln *listener) handleHandshake(p packet.Packet) {
return
}
// We only support HSv5
if cif.Version != 5 {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return "only HSv5 is supported" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// Check if the peer version is sufficient
if cif.SRTVersion < ln.config.MinVersion {
cif.HandshakeType = packet.REJ_VERSION
ln.log("handshake:recv:error", func() string {
return fmt.Sprintf("peer version insufficient (%#06x), expecting at least %#06x", cif.SRTVersion, ln.config.MinVersion)
})
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// Check the required SRT flags
if !cif.SRTFlags.TSBPDSND || !cif.SRTFlags.TSBPDRCV || !cif.SRTFlags.TLPKTDROP || !cif.SRTFlags.PERIODICNAK || !cif.SRTFlags.REXMITFLG {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return "not all required flags are set" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// We only support live streaming
if cif.SRTFlags.STREAM {
cif.HandshakeType = packet.REJ_MESSAGEAPI
ln.log("handshake:recv:error", func() string { return "only live streaming is supported" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// Peer is advertising a too big MSS
if cif.MaxTransmissionUnitSize > MAX_MSS_SIZE {
cif.HandshakeType = packet.REJ_ROGUE
@ -673,6 +641,68 @@ func (ln *listener) handleHandshake(p packet.Packet) {
}
}
// We only support HSv4 and HSv5
if cif.Version == 4 {
// Check if the type (encryption field + extension field) has the value 2
if cif.EncryptionField != 0 || cif.ExtensionField != 2 {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return "invalid type, expecting a value of 2 (UDT_DGRAM)" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
} else if cif.Version == 5 {
// Check if the peer version is sufficient
if cif.SRTHS.SRTVersion < ln.config.MinVersion {
cif.HandshakeType = packet.REJ_VERSION
ln.log("handshake:recv:error", func() string {
return fmt.Sprintf("peer version insufficient (%#06x), expecting at least %#06x", cif.SRTHS.SRTVersion, ln.config.MinVersion)
})
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// Check the required SRT flags
if !cif.SRTHS.SRTFlags.TSBPDSND || !cif.SRTHS.SRTFlags.TSBPDRCV || !cif.SRTHS.SRTFlags.TLPKTDROP || !cif.SRTHS.SRTFlags.PERIODICNAK || !cif.SRTHS.SRTFlags.REXMITFLG {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return "not all required flags are set" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// We only support live streaming
if cif.SRTHS.SRTFlags.STREAM {
cif.HandshakeType = packet.REJ_MESSAGEAPI
ln.log("handshake:recv:error", func() string { return "only live streaming is supported" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
} else {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("only HSv4 and HSv5 are supported (got HSv%d)", cif.Version) })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// Fill up a connection request with all relevant data and put it into the backlog
c := connRequest{

View File

@ -0,0 +1 @@
testdata/* linguist-vendored

View File

@ -0,0 +1,76 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at vasile.gabriel@email.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

Some files were not shown because too many files have changed in this diff Show More