diff --git a/internal/gitaly/storage/raftmgr/raft_enabled_storage.go b/internal/gitaly/storage/raftmgr/raft_enabled_storage.go index a58296690008836029d98bcafd6a8312353d3e7b..e1c89026c90ccde96bbadd0da3ebebd143d9cd6c 100644 --- a/internal/gitaly/storage/raftmgr/raft_enabled_storage.go +++ b/internal/gitaly/storage/raftmgr/raft_enabled_storage.go @@ -13,6 +13,7 @@ import ( // RaftEnabledStorage wraps a storage.Storage instance with Raft functionality type RaftEnabledStorage struct { storage.Storage + address string transport Transport routingTable RoutingTable replicaRegistry ReplicaRegistry @@ -33,6 +34,11 @@ func (s *RaftEnabledStorage) GetReplicaRegistry() ReplicaRegistry { return s.replicaRegistry } +// GetNodeAddress returns the node's address. +func (s *RaftEnabledStorage) GetNodeAddress() string { + return s.address +} + // RegisterReplica registers a replica with this RaftEnabledStorage // This should be called after both the replica and RaftEnabledStorage are created func (s *RaftEnabledStorage) RegisterReplica(replica *Replica) error { @@ -74,11 +80,17 @@ func NewNode(cfg config.Cfg, logger log.Logger, dbMgr *databasemgr.DBManager, co replicaRegistry := NewReplicaRegistry() transport := NewGrpcTransport(logger, cfg, routingTable, replicaRegistry, connsPool) + address, err := cfg.GetAddressWithScheme() + if err != nil { + return nil, fmt.Errorf("get address with scheme: %w", err) + } + n.storages[cfgStorage.Name] = &RaftEnabledStorage{ Storage: baseStorage, // storage.Storage would be nil initially transport: transport, routingTable: routingTable, replicaRegistry: replicaRegistry, + address: address, } } diff --git a/internal/gitaly/storage/raftmgr/replica.go b/internal/gitaly/storage/raftmgr/replica.go index ae8ef0c8128306bde5672e137ea97748b883d49f..c0bff5eaa1293fb564e6ba1c4c5e9572439d0d2f 100644 --- a/internal/gitaly/storage/raftmgr/replica.go +++ b/internal/gitaly/storage/raftmgr/replica.go @@ -3,6 +3,7 @@ package raftmgr import ( "context" 
"crypto/sha256" + "encoding/json" "errors" "fmt" "runtime" @@ -414,7 +415,20 @@ func (replica *Replica) Initialize(ctx context.Context, appliedLSN storage.LSN) switch initStatus { case InitStatusUnbootstrapped: // For first-time bootstrap, initialize with self as the only peer - peers := []raft.Peer{{ID: replica.memberID}} + nodeAddress := replica.raftEnabledStorage.GetNodeAddress() + + metadata := &gitalypb.ReplicaID_Metadata{ + Address: nodeAddress, + } + confChangeCtx := ConfChangeContext{ + Metadata: metadata, + } + contextBytes, err := json.Marshal(confChangeCtx) + if err != nil { + return fmt.Errorf("marshalling conf change context: %w", err) + } + peers := []raft.Peer{{ID: replica.memberID, Context: contextBytes}} + replica.node = raft.StartNode(config, peers) case InitStatusBootstrapped: // For restarts, set Applied to latest committed LSN diff --git a/internal/gitaly/storage/raftmgr/replica_test.go b/internal/gitaly/storage/raftmgr/replica_test.go index c9263ff7c353adf2d3f2b31b5d86791588738f61..81b24f1d7a9941a1df625f1791f36a536dc9bc0b 100644 --- a/internal/gitaly/storage/raftmgr/replica_test.go +++ b/internal/gitaly/storage/raftmgr/replica_test.go @@ -1729,6 +1729,10 @@ func TestReplica_AddNode(t *testing.T) { transport := NewGrpcTransport(logger, cfg, routingTable, registry, nil) socketPath, srv := createTempServer(t, transport) + cfg.SocketPath = socketPath + + address, err := cfg.GetAddressWithScheme() + require.NoError(t, err) replica, err := createRaftReplica(t, ctx, memberID, socketPath, raftCfg, partitionID, metrics, opts...)
require.NoError(t, err) @@ -1741,7 +1745,7 @@ func TestReplica_AddNode(t *testing.T) { err = replica.Initialize(ctx, 0) require.NoError(t, err) - return replica, socketPath, srv + return replica, address, srv } // waitForLeadership waits for the replica to become leader @@ -1783,8 +1787,7 @@ func TestReplica_AddNode(t *testing.T) { metrics := NewMetrics() partitionID := storage.PartitionID(1) - replica, socketPath, srv := createTestNode(t, ctx, 1, partitionID, raftCfg, metrics) - replicaAddress := "unix://" + socketPath + replica, replicaAddress, srv := createTestNode(t, ctx, 1, partitionID, raftCfg, metrics) defer func() { srv.Stop() require.NoError(t, replica.Close()) @@ -1799,9 +1802,21 @@ func TestReplica_AddNode(t *testing.T) { routingTable := raftEnabledStorage.GetRoutingTable() partitionKey := replica.partitionKey + // verify the routing table is updated + require.Eventually(t, func() bool { + entry, err := routingTable.GetEntry(partitionKey) + if err != nil { + return false + } + leaderEntry := entry.Replicas[0] + require.Equal(t, uint64(1), leaderEntry.GetMemberId()) + require.Equal(t, replicaAddress, leaderEntry.GetMetadata().GetAddress()) + require.Equal(t, replica.logStore.storageName, leaderEntry.GetStorageName()) + return len(entry.Replicas) == 1 + }, waitTimeout, 5*time.Millisecond, "routing table should be updated") + // Create second node - replicaTwo, socketPathTwo, srvTwo := createTestNode(t, ctx, 3, partitionID, raftCfg, metrics) - replicaTwoAddress := "unix://" + socketPathTwo + replicaTwo, replicaTwoAddress, srvTwo := createTestNode(t, ctx, 3, partitionID, raftCfg, metrics) defer func() { srvTwo.Stop() require.NoError(t, replicaTwo.Close()) @@ -1870,14 +1885,12 @@ func TestReplica_AddNode(t *testing.T) { metrics := NewMetrics() partitionID := storage.PartitionID(1) - replica, socketPath, srv := createTestNode(t, ctx, 1, partitionID, raftCfg, metrics) + replica, replicaOneAddress, srv := createTestNode(t, ctx, 1, partitionID, raftCfg, metrics) 
defer func() { srv.Stop() require.NoError(t, replica.Close()) }() - replicaOneAddress := "unix://" + socketPath - waitForLeadership(t, replica, waitTimeout) raftEnabledStorage := replica.raftEnabledStorage @@ -1904,10 +1917,9 @@ func TestReplica_AddNode(t *testing.T) { for i := uint64(2); i <= lastMemberID; i++ { // create multiple replicas with new addresses - replica, socketPath, srv := createTestNode(t, ctx, i, partitionID, raftCfg, metrics) + replica, address, srv := createTestNode(t, ctx, i, partitionID, raftCfg, metrics) servers = append(servers, srv) - address := "unix://" + socketPath destinationAddresses = append(destinationAddresses, address) addressesToReplicas[address] = replica