diff --git a/lib/gitlab/database/load_balancing/sticking.rb b/lib/gitlab/database/load_balancing/sticking.rb
index c9610a74bc890e562fed3d7aa3c58dc3795b622c..955e3b459ced31f37f9189213d2f8c3744e19b15 100644
--- a/lib/gitlab/database/load_balancing/sticking.rb
+++ b/lib/gitlab/database/load_balancing/sticking.rb
@@ -10,6 +10,21 @@ class Sticking
         # the primary.
         EXPIRATION = 30
 
+        # Lua compare-and-delete: removes KEYS[1] only if its current value equals ARGV[1].
+        # Returns 1 when the key was deleted, 0 when it holds another value or is missing.
+        UNSTICK_IF_CAUGHT_UP_SCRIPT = <<~LUA
+          local key = KEYS[1]
+          local expected_location = ARGV[1]
+          local current_location = redis.call('GET', key)
+
+          if current_location == expected_location then
+            redis.call('DEL', key)
+            return 1
+          else
+            return 0
+          end
+        LUA
+
         attr_reader :load_balancer
 
         def initialize(load_balancer)
@@ -35,7 +50,7 @@ def find_caught_up_replica(
           result = if location
                      up_to_date_result = @load_balancer.select_up_to_date_host(location)
 
-                     unstick(namespace, id) if up_to_date_result == LoadBalancer::ALL_CAUGHT_UP
+                     unstick_if_caught_up(namespace, id, location) if up_to_date_result == LoadBalancer::ALL_CAUGHT_UP
 
                      up_to_date_result != LoadBalancer::NONE_CAUGHT_UP
                    else
@@ -100,6 +115,17 @@ def unstick(namespace, id)
           end
         end
 
+        # Atomically unstick only if the sticking point still holds the location we read.
+        # This prevents a race where a concurrent request stores a newer write location
+        # after we verified all replicas caught up to the old one but before we unstick.
+        #
+        # Returns 1 if the key was deleted, 0 if it changed or no longer exists (e.g. expired).
+        def unstick_if_caught_up(namespace, id, expected_location)
+          with_redis do |redis|
+            redis.eval(UNSTICK_IF_CAUGHT_UP_SCRIPT, keys: [redis_key_for(namespace, id)], argv: [expected_location])
+          end
+        end
+
         def set_write_location_for(namespace, id, location)
           with_redis do |redis|
             redis.set(redis_key_for(namespace, id), location, ex: EXPIRATION)
diff --git a/spec/lib/gitlab/database/load_balancing/sticking_spec.rb b/spec/lib/gitlab/database/load_balancing/sticking_spec.rb
index d7ad34357605afbd0e5dfe082a559f5e523cfae0..116bbe9a4f54869d06bd40ed3cf9fdbe721e5b35 100644
--- a/spec/lib/gitlab/database/load_balancing/sticking_spec.rb
+++ b/spec/lib/gitlab/database/load_balancing/sticking_spec.rb
@@ -57,16 +57,39 @@
       end
 
       context 'when all replicas have caught up' do
-        it 'returns true and unsticks' do
+        it 'returns true and attempts to unstick if location matches' do
           expect(load_balancer).to receive(:select_up_to_date_host).with(last_write_location)
             .and_return(::Gitlab::Database::LoadBalancing::LoadBalancer::ALL_CAUGHT_UP)
 
           expect(redis)
-            .to receive(:del)
-            .with("database-load-balancing/write-location/#{load_balancer.name}/user/42")
+            .to receive(:eval)
+            .with(
+              described_class::UNSTICK_IF_CAUGHT_UP_SCRIPT,
+              keys: ["database-load-balancing/write-location/#{load_balancer.name}/user/42"],
+              argv: [last_write_location]
+            )
+            .and_return(1)
 
           expect(sticking.find_caught_up_replica(:user, 42)).to eq(true)
         end
+
+        context 'when the sticking point has changed (concurrent write)' do
+          it 'returns true but does not unstick' do
+            expect(load_balancer).to receive(:select_up_to_date_host).with(last_write_location)
+              .and_return(::Gitlab::Database::LoadBalancing::LoadBalancer::ALL_CAUGHT_UP)
+
+            expect(redis)
+              .to receive(:eval)
+              .with(
+                described_class::UNSTICK_IF_CAUGHT_UP_SCRIPT,
+                keys: ["database-load-balancing/write-location/#{load_balancer.name}/user/42"],
+                argv: [last_write_location]
+              )
+              .and_return(0)
+
+            expect(sticking.find_caught_up_replica(:user, 42)).to eq(true)
+          end
+        end
       end
 
       context 'when only some of the replicas have caught up' do