diff --git a/cluster/cluster.go b/cluster/cluster.go index ecf3745d..446cb97a 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -1,33 +1,24 @@ package cluster import ( - "bytes" "context" - "encoding/gob" "encoding/json" "fmt" "io" gonet "net" "net/url" - "path/filepath" - "reflect" "strconv" "sync" "time" apiclient "github.com/datarhei/core/v16/cluster/client" "github.com/datarhei/core/v16/cluster/forwarder" - raftlogger "github.com/datarhei/core/v16/cluster/logger" "github.com/datarhei/core/v16/cluster/proxy" + "github.com/datarhei/core/v16/cluster/raft" "github.com/datarhei/core/v16/cluster/store" "github.com/datarhei/core/v16/log" "github.com/datarhei/core/v16/net" "github.com/datarhei/core/v16/restream/app" - - hclog "github.com/hashicorp/go-hclog" - "github.com/hashicorp/raft" - raftboltdb "github.com/hashicorp/raft-boltdb/v2" - "go.etcd.io/bbolt" ) /* @@ -105,23 +96,17 @@ type cluster struct { logger log.Logger - raft *raft.Raft - raftTransport *raft.NetworkTransport - raftAddress string - raftNotifyCh chan bool - raftEmergencyNotifyCh chan bool - raftStore *raftboltdb.BoltStore - raftRemoveGracePeriod time.Duration - - peers []Peer + raft raft.Raft + raftRemoveGracePeriod time.Duration + raftAddress string + raftNotifyCh chan bool + raftEmergencyNotifyCh chan bool + raftLeaderObservationCh chan string store store.Store - reassertLeaderCh chan chan error cancelLeaderShip context.CancelFunc - leaveCh chan struct{} - shutdown bool shutdownCh chan struct{} shutdownLock sync.Mutex @@ -148,12 +133,10 @@ func New(config ClusterConfig) (Cluster, error) { path: config.Path, logger: config.Logger, - raftAddress: config.Address, - peers: config.Peers, + raftAddress: config.Address, + raftRemoveGracePeriod: 5 * time.Second, - reassertLeaderCh: make(chan chan error), - leaveCh: make(chan struct{}), - shutdownCh: make(chan struct{}), + shutdownCh: make(chan struct{}), nodes: map[string]proxy.Node{}, } @@ -204,7 +187,7 @@ func New(config ClusterConfig) (Cluster, error) { Logger: c.logger.WithField("logname", "proxy"), }) if err != nil { - c.shutdownAPI() + c.Shutdown() return nil, err } @@ -214,13 +197,11 @@ func New(config ClusterConfig) (Cluster, error) { c.proxy = nodeproxy - go c.trackNodeChanges() - if forwarder, err := forwarder.New(forwarder.ForwarderConfig{ ID: c.id, Logger: c.logger.WithField("logname", "forwarder"), }); err != nil { - c.shutdownAPI() + c.Shutdown() return nil, err } else { c.forwarder = forwarder @@ -228,18 +209,43 @@ func New(config ClusterConfig) (Cluster, error) { c.logger.Debug().Log("starting raft") - err = c.startRaft(store, config.Bootstrap, config.Recover, config.Peers, false) + peers := []raft.Peer{} + + for _, p := range config.Peers { + peers = append(peers, raft.Peer{ + ID: p.ID, + Address: p.Address, + }) + } + + c.raftNotifyCh = make(chan bool, 16) + c.raftLeaderObservationCh = make(chan string, 16) + c.raftEmergencyNotifyCh = make(chan bool, 16) + + raft, err := raft.New(raft.Config{ + ID: config.ID, + Path: config.Path, + Bootstrap: config.Bootstrap, + Recover: config.Recover, + Address: config.Address, + Peers: peers, + Store: store, + LeadershipNotifyCh: c.raftNotifyCh, + LeaderObservationCh: c.raftLeaderObservationCh, + Logger: c.logger.WithComponent("raft"), + }) if err != nil { - c.shutdownAPI() + c.Shutdown() return nil, err } + c.raft = raft + if len(config.Peers) != 0 { for i := 0; i < len(config.Peers); i++ { peerAddress, err := c.ClusterAPIAddress(config.Peers[i].Address) if err != nil { - c.shutdownAPI() - c.shutdownRaft() + 
c.Shutdown() return nil, err } @@ -265,6 +271,11 @@ func New(config ClusterConfig) (Cluster, error) { } } + go c.trackNodeChanges() + go c.trackLeaderChanges() + go c.monitorLeadership() + go c.sentinel() + return c, nil } @@ -324,15 +335,26 @@ func (c *cluster) Shutdown() error { for id, node := range c.nodes { node.Disconnect() - c.proxy.RemoveNode(id) + if c.proxy != nil { + c.proxy.RemoveNode(id) + } } if c.proxy != nil { c.proxy.Stop() } - c.shutdownAPI() - c.shutdownRaft() + if c.api != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + c.api.Shutdown(ctx) + } + + if c.raft != nil { + c.raft.Shutdown() + c.raft = nil + } return nil } @@ -369,16 +391,16 @@ func (c *cluster) Leave(origin, id string) error { time.Sleep(50 * time.Millisecond) // Get the latest configuration. - future := c.raft.GetConfiguration() - if err := future.Error(); err != nil { + servers, err := c.raft.Servers() + if err != nil { c.logger.Error().WithError(err).Log("failed to get raft configuration") break } // See if we are no longer included. left = true - for _, server := range future.Configuration().Servers { - if server.Address == raft.ServerAddress(c.raftAddress) { + for _, server := range servers { + if server.Address == c.raftAddress { left = false break } @@ -393,13 +415,13 @@ func (c *cluster) Leave(origin, id string) error { } // Count the number of servers in the cluster - future := c.raft.GetConfiguration() - if err := future.Error(); err != nil { + servers, err := c.raft.Servers() + if err != nil { c.logger.Error().WithError(err).Log("failed to get raft configuration") return err } - numPeers := len(future.Configuration().Servers) + numPeers := len(servers) if id == c.id { // We're going to remove ourselves @@ -446,16 +468,16 @@ func (c *cluster) Leave(origin, id string) error { time.Sleep(50 * time.Millisecond) // Get the latest configuration. - future := c.raft.GetConfiguration() - if err := future.Error(); err != nil { + servers, err := c.raft.Servers() + if err != nil { c.logger.Error().WithError(err).Log("failed to get raft configuration") break } // See if we are no longer included. left = true - for _, server := range future.Configuration().Servers { - if server.Address == raft.ServerAddress(c.raftAddress) { + for _, server := range servers { + if server.Address == c.raftAddress { left = false break } @@ -466,12 +488,13 @@ func (c *cluster) Leave(origin, id string) error { } // Remove another sever from the cluster - if future := c.raft.RemoveServer(raft.ServerID(id), 0, 0); future.Error() != nil { - c.logger.Error().WithError(future.Error()).WithFields(log.Fields{ + err = c.raft.RemoveServer(id) + if err != nil { + c.logger.Error().WithError(err).WithFields(log.Fields{ "nodeid": id, }).Log("failed to remove node") - return future.Error() + return err } return nil @@ -503,29 +526,29 @@ func (c *cluster) Join(origin, id, raftAddress, peerAddress string) error { } defer node.Disconnect() - configFuture := c.raft.GetConfiguration() - if err := configFuture.Error(); err != nil { + servers, err := c.raft.Servers() + if err != nil { c.logger.Error().WithError(err).Log("failed to get raft configuration") return err } nodeExists := false - for _, srv := range configFuture.Configuration().Servers { + for _, srv := range servers { // If a node already exists with either the joining node's ID or address, // that node may need to be removed from the config first. 
- if srv.ID == raft.ServerID(id) || srv.Address == raft.ServerAddress(raftAddress) { + if srv.ID == id || srv.Address == raftAddress { // However if *both* the ID and the address are the same, then nothing -- not even // a join operation -- is needed. - if srv.ID == raft.ServerID(id) && srv.Address == raft.ServerAddress(raftAddress) { + if srv.ID == id && srv.Address == raftAddress { nodeExists = true c.logger.Debug().WithFields(log.Fields{ "nodeid": id, "address": raftAddress, }).Log("node is already member of cluster, ignoring join request") } else { - future := c.raft.RemoveServer(srv.ID, 0, 0) - if err := future.Error(); err != nil { + err := c.raft.RemoveServer(srv.ID) + if err != nil { c.logger.Error().WithError(err).WithFields(log.Fields{ "nodeid": id, "address": raftAddress, @@ -537,8 +560,8 @@ func (c *cluster) Join(origin, id, raftAddress, peerAddress string) error { } if !nodeExists { - f := c.raft.AddVoter(raft.ServerID(id), raft.ServerAddress(raftAddress), 0, 0) - if err := f.Error(); err != nil { + err := c.raft.AddServer(id, raftAddress) + if err != nil { return err } } @@ -551,60 +574,13 @@ func (c *cluster) Join(origin, id, raftAddress, peerAddress string) error { return nil } -type Snapshot struct { - Metadata *raft.SnapshotMeta - Data []byte -} - func (c *cluster) Snapshot() (io.ReadCloser, error) { if !c.IsRaftLeader() { c.logger.Debug().Log("not leader, forwarding to leader") return c.forwarder.Snapshot() } - f := c.raft.Snapshot() - err := f.Error() - if err != nil { - return nil, err - } - - metadata, r, err := f.Open() - if err != nil { - return nil, err - } - - defer r.Close() - - data, err := io.ReadAll(r) - if err != nil { - return nil, fmt.Errorf("failed to read in snapshot: %w", err) - } - - snapshot := Snapshot{ - Metadata: metadata, - Data: data, - } - - buffer := bytes.Buffer{} - enc := gob.NewEncoder(&buffer) - err = enc.Encode(snapshot) - if err != nil { - return nil, err - } - - return &readCloserWrapper{&buffer}, nil -} - -type readCloserWrapper struct { - io.Reader -} - -func (rcw *readCloserWrapper) Read(p []byte) (int, error) { - return rcw.Reader.Read(p) -} - -func (rcw *readCloserWrapper) Close() error { - return nil + return c.raft.Snapshot() } func (c *cluster) trackNodeChanges() { @@ -615,8 +591,8 @@ func (c *cluster) trackNodeChanges() { select { case <-ticker.C: // Get the latest configuration. - future := c.raft.GetConfiguration() - if err := future.Error(); err != nil { + servers, err := c.raft.Servers() + if err != nil { c.logger.Error().WithError(err).Log("failed to get raft configuration") continue } @@ -628,12 +604,12 @@ func (c *cluster) trackNodeChanges() { removeNodes[id] = struct{}{} } - for _, server := range future.Configuration().Servers { - id := string(server.ID) + for _, server := range servers { + id := server.ID _, ok := c.nodes[id] if !ok { - address, err := c.CoreAPIAddress(string(server.Address)) + address, err := c.CoreAPIAddress(server.Address) if err != nil { c.logger.Warn().WithError(err).WithFields(log.Fields{ "id": id, @@ -686,230 +662,23 @@ func (c *cluster) trackNodeChanges() { // trackLeaderChanges registers an Observer with raft in order to receive updates // about leader changes, in order to keep the forwarder up to date. 
func (c *cluster) trackLeaderChanges() { - obsCh := make(chan raft.Observation, 16) - observer := raft.NewObserver(obsCh, false, func(o *raft.Observation) bool { - _, leaderOK := o.Data.(raft.LeaderObservation) - _, peerOK := o.Data.(raft.PeerObservation) - - return leaderOK || peerOK - }) - c.raft.RegisterObserver(observer) - for { select { - case obs := <-obsCh: - if leaderObs, ok := obs.Data.(raft.LeaderObservation); ok { - c.logger.Debug().WithFields(log.Fields{ - "id": leaderObs.LeaderID, - "address": leaderObs.LeaderAddr, - }).Log("new leader observation") - addr := string(leaderObs.LeaderAddr) - if len(addr) != 0 { - addr, _ = c.ClusterAPIAddress(addr) - } - c.forwarder.SetLeader(addr) - c.leaderLock.Lock() - if len(addr) == 0 { - c.hasRaftLeader = false - } else { - c.hasRaftLeader = true - } - c.leaderLock.Unlock() - } else if peerObs, ok := obs.Data.(raft.PeerObservation); ok { - c.logger.Debug().WithFields(log.Fields{ - "removed": peerObs.Removed, - "address": peerObs.Peer.Address, - }).Log("new peer observation") - } else { - c.logger.Debug().WithField("type", reflect.TypeOf(obs.Data)).Log("got unknown observation type from raft") - continue + case leaderAddress := <-c.raftLeaderObservationCh: + c.logger.Debug().WithFields(log.Fields{ + "address": leaderAddress, + }).Log("new leader observation") + if len(leaderAddress) != 0 { + leaderAddress, _ = c.ClusterAPIAddress(leaderAddress) } - case <-c.shutdownCh: - c.raft.DeregisterObserver(observer) - return - } - } -} - -func (c *cluster) startRaft(fsm raft.FSM, bootstrap, recover bool, peers []Peer, inmem bool) error { - defer func() { - if c.raft == nil && c.raftStore != nil { - c.raftStore.Close() - } - }() - - c.raftRemoveGracePeriod = 5 * time.Second - - addr, err := gonet.ResolveTCPAddr("tcp", c.raftAddress) - if err != nil { - return err - } - - c.logger.Debug().Log("address: %s", addr) - - transport, err := raft.NewTCPTransportWithLogger(c.raftAddress, addr, 3, 10*time.Second, raftlogger.New(c.logger, hclog.Debug).Named("raft-transport")) - if err != nil { - return err - } - - c.raftTransport = transport - - snapshotLogger := raftlogger.New(c.logger, hclog.Debug).Named("raft-snapshot") - snapshots, err := raft.NewFileSnapshotStoreWithLogger(c.path, 3, snapshotLogger) - if err != nil { - return err - } - - var logStore raft.LogStore - var stableStore raft.StableStore - if inmem { - logStore = raft.NewInmemStore() - stableStore = raft.NewInmemStore() - } else { - bolt, err := raftboltdb.New(raftboltdb.Options{ - Path: filepath.Join(c.path, "raftlog.db"), - BoltOptions: &bbolt.Options{ - Timeout: 5 * time.Second, - }, - }) - if err != nil { - return fmt.Errorf("bolt: %w", err) - } - logStore = bolt - stableStore = bolt - - cacheStore, err := raft.NewLogCache(512, logStore) - if err != nil { - return err - } - logStore = cacheStore - - c.raftStore = bolt - } - - cfg := raft.DefaultConfig() - cfg.LocalID = raft.ServerID(c.id) - cfg.Logger = raftlogger.New(c.logger, hclog.Debug).Named("raft") - - hasState, err := raft.HasExistingState(logStore, stableStore, snapshots) - if err != nil { - return err - } - - if !hasState { - // Bootstrap cluster - servers := []raft.Server{ - { - Suffrage: raft.Voter, - ID: raft.ServerID(c.id), - Address: transport.LocalAddr(), - }, - } - - for _, p := range peers { - servers = append(servers, raft.Server{ - Suffrage: raft.Voter, - ID: raft.ServerID(p.ID), - Address: raft.ServerAddress(p.Address), - }) - } - - configuration := raft.Configuration{ - Servers: servers, - } - - if err := 
raft.BootstrapCluster(cfg, logStore, stableStore, snapshots, transport, configuration); err != nil { - return err - } - - c.logger.Debug().Log("raft node bootstrapped") - } else { - // Recover cluster - fsm, err := store.NewStore() - if err != nil { - return err - } - - servers := []raft.Server{ - { - Suffrage: raft.Voter, - ID: raft.ServerID(c.id), - Address: transport.LocalAddr(), - }, - } - - for _, p := range peers { - servers = append(servers, raft.Server{ - Suffrage: raft.Voter, - ID: raft.ServerID(p.ID), - Address: raft.ServerAddress(p.Address), - }) - } - - configuration := raft.Configuration{ - Servers: servers, - } - - if err := raft.RecoverCluster(cfg, fsm, logStore, stableStore, snapshots, transport, configuration); err != nil { - return err - } - - c.logger.Debug().Log("raft node recoverd") - } - - // Set up a channel for reliable leader notifications. - raftNotifyCh := make(chan bool, 10) - cfg.NotifyCh = raftNotifyCh - c.raftNotifyCh = raftNotifyCh - - c.raftEmergencyNotifyCh = make(chan bool, 10) - - node, err := raft.NewRaft(cfg, fsm, logStore, stableStore, snapshots, transport) - if err != nil { - return err - } - - c.raft = node - - go c.trackLeaderChanges() - go c.monitorLeadership() - go c.sentinel() - - c.logger.Debug().Log("raft started") - - return nil -} - -func (c *cluster) shutdownRaft() { - if c.raft != nil { - c.raftTransport.Close() - future := c.raft.Shutdown() - if err := future.Error(); err != nil { - c.logger.Warn().WithError(err).Log("error shutting down raft") - } - if c.raftStore != nil { - c.raftStore.Close() - } - } -} - -func (c *cluster) shutdownAPI() { - if c.api != nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - c.api.Shutdown(ctx) - } -} - -// nodeLoop is run by every node in the cluster. This is mainly to check the list -// of nodes from the FSM, in order to connect to them and to fetch their file lists. 
-func (c *cluster) followerLoop(stopCh chan struct{}) { - // Periodically reconcile as long as we are the leader - for { - select { - case <-stopCh: - return + c.forwarder.SetLeader(leaderAddress) + c.leaderLock.Lock() + if len(leaderAddress) == 0 { + c.hasRaftLeader = false + } else { + c.hasRaftLeader = true + } + c.leaderLock.Unlock() case <-c.shutdownCh: return } @@ -952,8 +721,8 @@ func (c *cluster) applyCommand(cmd *store.Command) error { return err } - future := c.raft.Apply(b, 5*time.Second) - if err := future.Error(); err != nil { + err = c.raft.Apply(b) + if err != nil { return fmt.Errorf("applying command failed: %w", err) } @@ -998,30 +767,22 @@ func (c *cluster) About() (ClusterAbout, error) { stats := c.raft.Stats() - about.Stats.State = stats["state"] + about.Stats.State = stats.State + about.Stats.LastContact = stats.LastContact + about.Stats.NumPeers = stats.NumPeers - if x, err := time.ParseDuration(stats["last_contact"]); err == nil { - about.Stats.LastContact = x - } - - if x, err := strconv.ParseUint(stats["num_peers"], 10, 64); err == nil { - about.Stats.NumPeers = x - } - - _, leaderID := c.raft.LeaderWithID() - - future := c.raft.GetConfiguration() - if err := future.Error(); err != nil { + servers, err := c.raft.Servers() + if err != nil { c.logger.Error().WithError(err).Log("failed to get raft configuration") return ClusterAbout{}, err } - for _, server := range future.Configuration().Servers { + for _, server := range servers { node := ClusterServer{ - ID: string(server.ID), - Address: string(server.Address), - Voter: server.Suffrage == raft.Voter, - Leader: server.ID == leaderID, + ID: server.ID, + Address: server.Address, + Voter: server.Voter, + Leader: server.Leader, } about.Nodes = append(about.Nodes, node) @@ -1034,9 +795,6 @@ func (c *cluster) sentinel() { ticker := time.NewTicker(time.Second) defer ticker.Stop() - start := time.Now() - var lastContactSince time.Duration - isEmergencyLeader := false for { @@ -1046,30 +804,17 @@ func (c *cluster) sentinel() { case <-ticker.C: stats := c.raft.Stats() - fields := log.Fields{} - for k, v := range stats { - fields[k] = v - } + c.logger.Debug().WithFields(log.Fields{ + "state": stats.State, + "last_contact": stats.LastContact.String(), + "num_peers": stats.NumPeers, + }).Log("stats") - c.logger.Debug().WithFields(fields).Log("stats") - - lastContact := stats["last_contact"] - if lastContact == "never" { - lastContactSince = time.Since(start) - } else { - if d, err := time.ParseDuration(lastContact); err == nil { - lastContactSince = d - start = time.Now() - } else { - lastContactSince = time.Since(start) - } - } - - if lastContactSince > 10*time.Second && !isEmergencyLeader { + if stats.LastContact > 10*time.Second && !isEmergencyLeader { c.logger.Warn().Log("force leadership due to lost contact to leader") c.raftEmergencyNotifyCh <- true isEmergencyLeader = true - } else if lastContactSince <= 10*time.Second && isEmergencyLeader { + } else if stats.LastContact <= 10*time.Second && isEmergencyLeader { c.logger.Warn().Log("stop forced leadership due to contact to leader") c.raftEmergencyNotifyCh <- false isEmergencyLeader = false diff --git a/cluster/follower.go b/cluster/follower.go new file mode 100644 index 00000000..f62b5841 --- /dev/null +++ b/cluster/follower.go @@ -0,0 +1,14 @@ +package cluster + +// followerLoop is run by every follower node in the cluster. 
+func (c *cluster) followerLoop(stopCh chan struct{}) {
+	// Wait here until we are asked to stop or the cluster shuts down
+	for {
+		select {
+		case <-stopCh:
+			return
+		case <-c.shutdownCh:
+			return
+		}
+	}
+}
diff --git a/cluster/leader.go b/cluster/leader.go
index 5db6a0ff..c5ce188e 100644
--- a/cluster/leader.go
+++ b/cluster/leader.go
@@ -17,9 +17,6 @@ const NOTIFY_FOLLOWER = 0
 const NOTIFY_LEADER = 1
 const NOTIFY_EMERGENCY = 2
 
-// monitorLeadership listens to the raf notify channel in order to find
-// out if we got the leadership or lost it.
-// https://github.com/hashicorp/consul/blob/44b39240a86bc94ddc67bc105286ab450bd869a9/agent/consul/leader.go#L71
 func (c *cluster) monitorLeadership() {
 	// We use the notify channel we configured Raft with, NOT Raft's
 	// leaderCh, which is only notified best-effort. Doing this ensures
@@ -187,12 +184,11 @@ func (c *cluster) monitorLeadership() {
 
 // leadershipTransfer tries to transfer the leadership to another node e.g. in order
 // to do a graceful shutdown.
-// https://github.com/hashicorp/consul/blob/44b39240a86bc94ddc67bc105286ab450bd869a9/agent/consul/leader.go#L122
 func (c *cluster) leadershipTransfer() error {
 	retryCount := 3
 	for i := 0; i < retryCount; i++ {
-		future := c.raft.LeadershipTransfer()
-		if err := future.Error(); err != nil {
+		err := c.raft.LeadershipTransfer()
+		if err != nil {
 			c.logger.Error().WithError(err).WithFields(log.Fields{
 				"attempt":     i,
 				"retry_limit": retryCount,
@@ -234,8 +230,8 @@ RECONCILE:
 
 	// Apply a raft barrier to ensure our FSM is caught up
 	if !emergency {
-		barrier := c.raft.Barrier(time.Minute)
-		if err := barrier.Error(); err != nil {
+		err := c.raft.Barrier(time.Minute)
+		if err != nil {
 			c.logger.Error().WithError(err).Log("failed to wait for barrier")
 			goto WAIT
 		}
diff --git a/cluster/raft/raft.go b/cluster/raft/raft.go
new file mode 100644
index 00000000..2320ef2e
--- /dev/null
+++ b/cluster/raft/raft.go
@@ -0,0 +1,512 @@
+package raft
+
+import (
+	"bytes"
+	"encoding/gob"
+	"fmt"
+	"io"
+	gonet "net"
+	"path/filepath"
+	"reflect"
+	"strconv"
+	"sync"
+	"time"
+
+	raftlogger "github.com/datarhei/core/v16/cluster/logger"
+	"github.com/datarhei/core/v16/cluster/store"
+	"github.com/datarhei/core/v16/log"
+	"go.etcd.io/bbolt"
+
+	"github.com/hashicorp/go-hclog"
+	hcraft "github.com/hashicorp/raft"
+	raftboltdb "github.com/hashicorp/raft-boltdb/v2"
+)
+
+type Raft interface {
+	Shutdown()
+	IsLeader() bool
+	Leader() string
+	Servers() ([]Server, error)
+	Stats() Stats
+	Apply([]byte) error
+
+	Barrier(time.Duration) error
+
+	AddServer(id, address string) error
+	RemoveServer(id string) error
+	LeadershipTransfer() error
+
+	Snapshot() (io.ReadCloser, error)
+}
+
+type raft struct {
+	id   string
+	path string
+
+	raft          *hcraft.Raft
+	raftTransport *hcraft.NetworkTransport
+	raftAddress   string
+	raftNotifyCh  chan bool
+	raftStore     *raftboltdb.BoltStore
+	raftStart     time.Time
+
+	shutdown     bool
+	shutdownCh   chan struct{}
+	shutdownLock sync.Mutex
+
+	leadershipNotifyCh  chan bool
+	leaderObservationCh chan string
+
+	isLeader   bool
+	leaderLock sync.Mutex
+
+	logger log.Logger
+}
+
+type Peer struct {
+	ID      string
+	Address string
+}
+
+type Server struct {
+	ID      string
+	Address string
+	Voter   bool
+	Leader  bool
+}
+
+type Stats struct {
+	State       string
+	LastContact time.Duration
+	NumPeers    uint64
+}
+
+type Config struct {
+	ID        string // ID of the node
+	Path      string // Path where to store all cluster data
+	Bootstrap bool   // Whether to bootstrap a cluster
+	Recover   bool   // Whether to recover this node
+	Address   string // Listen address for the raft protocol
+	Peers     []Peer // IDs and addresses of the peers to join
+
+	Store hcraft.FSM
+
+	// A channel to which "true" is pushed if this server is the leader
+	// and "false" if it is not.
+	LeadershipNotifyCh chan bool
+
+	// A channel to which leader observations are pushed. An observation is
+	// the address of the current leader, or an empty string if there is
+	// currently no leader.
+	LeaderObservationCh chan string
+
+	Logger log.Logger
+}
+
+func New(config Config) (Raft, error) {
+	r := &raft{
+		id:                  config.ID,
+		path:                config.Path,
+		raftAddress:         config.Address,
+		leadershipNotifyCh:  config.LeadershipNotifyCh,
+		leaderObservationCh: config.LeaderObservationCh,
+		shutdownCh:          make(chan struct{}),
+		logger:              config.Logger,
+	}
+
+	err := r.start(config.Store, config.Bootstrap, config.Recover, config.Peers, false)
+	if err != nil {
+		return nil, fmt.Errorf("failed to start raft: %w", err)
+	}
+
+	r.raftStart = time.Now()
+
+	return r, nil
+}
+
+func (r *raft) Shutdown() {
+	r.shutdownLock.Lock()
+	defer r.shutdownLock.Unlock()
+
+	if r.shutdown {
+		return
+	}
+
+	r.shutdown = true
+	close(r.shutdownCh)
+
+	if r.raft != nil {
+		r.raftTransport.Close()
+		future := r.raft.Shutdown()
+		if err := future.Error(); err != nil {
+			r.logger.Warn().WithError(err).Log("error shutting down raft")
+		}
+		if r.raftStore != nil {
+			r.raftStore.Close()
+		}
+	}
+}
+
+func (r *raft) IsLeader() bool {
+	r.leaderLock.Lock()
+	defer r.leaderLock.Unlock()
+
+	return r.isLeader
+}
+
+func (r *raft) Leader() string {
+	_, leaderID := r.raft.LeaderWithID()
+
+	return string(leaderID)
+}
+
+func (r *raft) Servers() ([]Server, error) {
+	future := r.raft.GetConfiguration()
+	if err := future.Error(); err != nil {
+		return nil, fmt.Errorf("failed to get raft configuration: %w", err)
+	}
+
+	leaderID := r.Leader()
+
+	servers := []Server{}
+
+	for _, server := range future.Configuration().Servers {
+		node := Server{
+			ID:      string(server.ID),
+			Address: string(server.Address),
+			Voter:   server.Suffrage == hcraft.Voter,
+			Leader:  string(server.ID) == leaderID,
+		}
+
+		servers = append(servers, node)
+	}
+
+	return servers, nil
+}
+
+func (r *raft) Stats() Stats {
+	stats := Stats{}
+
+	s := r.raft.Stats()
+
+	stats.State = s["state"]
+
+	var lastContactSince time.Duration
+
+	lastContact := s["last_contact"]
+	if lastContact == "never" {
+		lastContactSince = time.Since(r.raftStart)
+	} else {
+		if d, err := time.ParseDuration(lastContact); err == nil {
+			lastContactSince = d
+		} else {
+			lastContactSince = time.Since(r.raftStart)
+		}
+	}
+
+	stats.LastContact = lastContactSince
+
+	if x, err := strconv.ParseUint(s["num_peers"], 10, 64); err == nil {
+		stats.NumPeers = x
+	}
+
+	return stats
+}
+
+func (r *raft) Apply(data []byte) error {
+	future := r.raft.Apply(data, 5*time.Second)
+	if err := future.Error(); err != nil {
+		return fmt.Errorf("applying command failed: %w", err)
+	}
+
+	return nil
+}
+
+func (r *raft) Barrier(timeout time.Duration) error {
+	future := r.raft.Barrier(timeout)
+	if err := future.Error(); err != nil {
+		return fmt.Errorf("error while waiting for the barrier: %w", err)
+	}
+
+	return nil
+}
+
+func (r *raft) AddServer(id, address string) error {
+	future := r.raft.AddVoter(hcraft.ServerID(id), hcraft.ServerAddress(address), 0, 0)
+	if err := future.Error(); err != nil {
+		return fmt.Errorf("error adding server %s@%s: %w", id, address, err)
+	}
+
+	return nil
+}
+
+func (r *raft) RemoveServer(id string) error {
+	future := r.raft.RemoveServer(hcraft.ServerID(id), 0, 0)
+	if err := future.Error(); err != nil {
+		return 
fmt.Errorf("error removing server %s: %w", id, err) + } + + return nil +} + +func (r *raft) LeadershipTransfer() error { + future := r.raft.LeadershipTransfer() + if err := future.Error(); err != nil { + return fmt.Errorf("failed to transfer leadership: %w", err) + } + + return nil +} + +type readCloserWrapper struct { + io.Reader +} + +func (rcw *readCloserWrapper) Read(p []byte) (int, error) { + return rcw.Reader.Read(p) +} + +func (rcw *readCloserWrapper) Close() error { + return nil +} + +type Snapshot struct { + Metadata *hcraft.SnapshotMeta + Data []byte +} + +func (r *raft) Snapshot() (io.ReadCloser, error) { + f := r.raft.Snapshot() + err := f.Error() + if err != nil { + return nil, err + } + + metadata, reader, err := f.Open() + if err != nil { + return nil, err + } + + defer reader.Close() + + data, err := io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("failed to create snapshot: %w", err) + } + + snapshot := Snapshot{ + Metadata: metadata, + Data: data, + } + + buffer := bytes.Buffer{} + enc := gob.NewEncoder(&buffer) + err = enc.Encode(snapshot) + if err != nil { + return nil, err + } + + return &readCloserWrapper{&buffer}, nil +} + +func (r *raft) start(fsm hcraft.FSM, bootstrap, recover bool, peers []Peer, inmem bool) error { + defer func() { + if r.raft == nil && r.raftStore != nil { + r.raftStore.Close() + } + }() + + addr, err := gonet.ResolveTCPAddr("tcp", r.raftAddress) + if err != nil { + return err + } + + r.logger.Debug().Log("address: %s", addr) + + transport, err := hcraft.NewTCPTransportWithLogger(r.raftAddress, addr, 3, 10*time.Second, raftlogger.New(r.logger, hclog.Debug).Named("raft-transport")) + if err != nil { + return err + } + + r.raftTransport = transport + + snapshotLogger := raftlogger.New(r.logger, hclog.Debug).Named("raft-snapshot") + snapshots, err := hcraft.NewFileSnapshotStoreWithLogger(r.path, 3, snapshotLogger) + if err != nil { + return err + } + + var logStore hcraft.LogStore + var stableStore hcraft.StableStore + if inmem { + logStore = hcraft.NewInmemStore() + stableStore = hcraft.NewInmemStore() + } else { + bolt, err := raftboltdb.New(raftboltdb.Options{ + Path: filepath.Join(r.path, "raftlog.db"), + BoltOptions: &bbolt.Options{ + Timeout: 5 * time.Second, + }, + }) + if err != nil { + return fmt.Errorf("bolt: %w", err) + } + logStore = bolt + stableStore = bolt + + cacheStore, err := hcraft.NewLogCache(512, logStore) + if err != nil { + return err + } + logStore = cacheStore + + r.raftStore = bolt + } + + cfg := hcraft.DefaultConfig() + cfg.LocalID = hcraft.ServerID(r.id) + cfg.Logger = raftlogger.New(r.logger, hclog.Debug).Named("raft") + + hasState, err := hcraft.HasExistingState(logStore, stableStore, snapshots) + if err != nil { + return err + } + + if !hasState { + // Bootstrap cluster + servers := []hcraft.Server{ + { + Suffrage: hcraft.Voter, + ID: hcraft.ServerID(r.id), + Address: transport.LocalAddr(), + }, + } + + for _, p := range peers { + servers = append(servers, hcraft.Server{ + Suffrage: hcraft.Voter, + ID: hcraft.ServerID(p.ID), + Address: hcraft.ServerAddress(p.Address), + }) + } + + configuration := hcraft.Configuration{ + Servers: servers, + } + + if err := hcraft.BootstrapCluster(cfg, logStore, stableStore, snapshots, transport, configuration); err != nil { + return fmt.Errorf("bootstrapping cluster: %w", err) + } + + r.logger.Debug().Log("raft node bootstrapped") + } else { + // Recover cluster + fsm, err := store.NewStore() + if err != nil { + return err + } + + servers := []hcraft.Server{ + { + 
Suffrage: hcraft.Voter,
+				ID:       hcraft.ServerID(r.id),
+				Address:  transport.LocalAddr(),
+			},
+		}
+
+		for _, p := range peers {
+			servers = append(servers, hcraft.Server{
+				Suffrage: hcraft.Voter,
+				ID:       hcraft.ServerID(p.ID),
+				Address:  hcraft.ServerAddress(p.Address),
+			})
+		}
+
+		configuration := hcraft.Configuration{
+			Servers: servers,
+		}
+
+		if err := hcraft.RecoverCluster(cfg, fsm, logStore, stableStore, snapshots, transport, configuration); err != nil {
+			return fmt.Errorf("recovering cluster: %w", err)
+		}
+
+		r.logger.Debug().Log("raft node recovered")
+	}
+
+	// Set up a channel for reliable leader notifications.
+	raftNotifyCh := make(chan bool, 10)
+	cfg.NotifyCh = raftNotifyCh
+	r.raftNotifyCh = raftNotifyCh
+
+	node, err := hcraft.NewRaft(cfg, fsm, logStore, stableStore, snapshots, transport)
+	if err != nil {
+		return fmt.Errorf("creating raft: %w", err)
+	}
+
+	r.raft = node
+
+	go r.trackLeaderChanges()
+	go r.monitorLeadership()
+
+	r.logger.Debug().Log("raft started")
+
+	return nil
+}
+
+func (r *raft) monitorLeadership() {
+	for {
+		select {
+		case isLeader := <-r.raftNotifyCh:
+			r.leaderLock.Lock()
+			r.isLeader = isLeader
+			r.leaderLock.Unlock()
+
+			if r.leadershipNotifyCh != nil {
+				select {
+				case r.leadershipNotifyCh <- isLeader:
+				default:
+				}
+			}
+
+			r.logger.Debug().WithField("leader", isLeader).Log("leader notification")
+		case <-r.shutdownCh:
+			return
+		}
+	}
+}
+
+func (r *raft) trackLeaderChanges() {
+	obsCh := make(chan hcraft.Observation, 16)
+	observer := hcraft.NewObserver(obsCh, false, func(o *hcraft.Observation) bool {
+		_, leaderOK := o.Data.(hcraft.LeaderObservation)
+
+		return leaderOK
+	})
+
+	r.raft.RegisterObserver(observer)
+
+	for {
+		select {
+		case obs := <-obsCh:
+			if leaderObs, ok := obs.Data.(hcraft.LeaderObservation); ok {
+				r.logger.Debug().WithFields(log.Fields{
+					"id":      leaderObs.LeaderID,
+					"address": leaderObs.LeaderAddr,
+				}).Log("new leader observation")
+
+				leaderAddress := string(leaderObs.LeaderAddr)
+
+				if r.leaderObservationCh != nil {
+					select {
+					case r.leaderObservationCh <- leaderAddress:
+					default:
+					}
+				}
+			} else {
+				r.logger.Debug().WithField("type", reflect.TypeOf(obs.Data)).Log("got unknown observation type from raft")
+				continue
+			}
+		case <-c.shutdownCh:
+			r.raft.DeregisterObserver(observer)
+			return
+		}
+	}
+}
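Note (not part of the patch): below is a minimal sketch of how a caller is expected to wire up the new `cluster/raft` package, mirroring what `cluster.New` does above. The node ID, paths, addresses, and channel sizes are illustrative assumptions, and `log.New("raft")` stands in for whatever logger the caller already has.

```go
package main

import (
	"fmt"

	"github.com/datarhei/core/v16/cluster/raft"
	"github.com/datarhei/core/v16/cluster/store"
	"github.com/datarhei/core/v16/log"
)

func main() {
	// The FSM that raft applies committed log entries to.
	fsm, err := store.NewStore()
	if err != nil {
		panic(err)
	}

	// Buffered channels so raft never blocks on a slow consumer
	// (raft.go drops a notification when the channel is full).
	leadershipCh := make(chan bool, 16)
	leaderObsCh := make(chan string, 16)

	node, err := raft.New(raft.Config{
		ID:                  "node1",          // illustrative values
		Path:                "/data/cluster",  // where raft keeps its log and snapshots
		Address:             "127.0.0.1:8090", // raft listen address
		Bootstrap:           true,
		Peers:               nil,
		Store:               fsm,
		LeadershipNotifyCh:  leadershipCh,
		LeaderObservationCh: leaderObsCh,
		Logger:              log.New("raft"), // assumes the core log constructor
	})
	if err != nil {
		panic(err)
	}
	defer node.Shutdown()

	// React to leadership changes and leader observations, as the
	// cluster's monitorLeadership and trackLeaderChanges loops do.
	for {
		select {
		case isLeader := <-leadershipCh:
			fmt.Println("this node is leader:", isLeader)
		case addr := <-leaderObsCh:
			fmt.Println("observed leader address:", addr)
		}
	}
}
```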