From daaa911813417e2ffe1dc019293b0af33cc07627 Mon Sep 17 00:00:00 2001
From: Rutger Wessels
Date: Wed, 18 Jun 2025 17:39:48 +0200
Subject: [PATCH 1/3] Add class for parsing SELECT statements

This supports two cases:
- SELECT * FROM <table>
- SELECT * FROM <table> u JOIN <table> ON ...
---
 lib/gitlab/database/sharding/parse/select.rb | 165 +++++++++++++++++++
 1 file changed, 165 insertions(+)
 create mode 100644 lib/gitlab/database/sharding/parse/select.rb

diff --git a/lib/gitlab/database/sharding/parse/select.rb b/lib/gitlab/database/sharding/parse/select.rb
new file mode 100644
index 00000000000000..3b044ee795fb50
--- /dev/null
+++ b/lib/gitlab/database/sharding/parse/select.rb
@@ -0,0 +1,165 @@
+# frozen_string_literal: true
+
+module Gitlab
+  module Database
+    module Sharding
+      module Parse
+        class Select
+          NODE_FACTORIES = {
+            string: ->(str) {
+              PgQuery::Node.new(string: PgQuery::String.new(sval: str))
+            },
+
+            integer: ->(val) {
+              PgQuery::Node.new(a_const: PgQuery::A_Const.new(
+                ival: PgQuery::Integer.new(ival: val),
+                isnull: false
+              ))
+            },
+
+            column_ref: ->(table_alias, column) {
+              PgQuery::Node.new(column_ref: PgQuery::ColumnRef.new(
+                fields: [
+                  PgQuery::Node.new(string: PgQuery::String.new(sval: table_alias)),
+                  PgQuery::Node.new(string: PgQuery::String.new(sval: column))
+                ]
+              ))
+            },
+
+            equality: ->(left_node, right_node) {
+              PgQuery::Node.new(a_expr: PgQuery::A_Expr.new(
+                kind: :AEXPR_OP,
+                name: [PgQuery::Node.new(string: PgQuery::String.new(sval: "="))],
+                lexpr: left_node,
+                rexpr: right_node
+              ))
+            },
+
+            and_expr: ->(left_node, right_node) {
+              PgQuery::Node.new(bool_expr: PgQuery::BoolExpr.new(
+                boolop: :AND_EXPR,
+                args: [left_node, right_node]
+              ))
+            }
+          }.freeze
+
+          def initialize(sql, tenant_conditions = {}, table_dictionary: nil)
+            @sql = sql
+            @tenant_conditions = tenant_conditions
+            @table_dictionary = table_dictionary || Tenant::TableDictionary.new
+            @parsed = nil # Lazy parse
+          end
+
+          def add_tenant_filter
+            return @sql if @tenant_conditions.empty?
+
+            # Only parse if we actually need to modify
+            @parsed ||= PgQuery.parse(@sql)
+
+            modified = false
+
+            # Process each statement
+            @parsed.tree.stmts.each do |raw_stmt|
+              stmt = raw_stmt.stmt
+              next unless stmt.select_stmt
+
+              select_stmt = stmt.select_stmt
+              tables_to_filter = extract_filterable_tables(select_stmt)
+
+              next if tables_to_filter.empty? || too_complex?(select_stmt)
+
+              add_where_conditions_fast(select_stmt, tables_to_filter)
+              modified = true
+            end
+
+            modified ? @parsed.deparse : @sql
+          end
+
+          private
+
+          # Fast condition building using pre-built factories
+          def add_where_conditions_fast(select_stmt, tables)
+            # Build all tenant conditions at once
+            tenant_conditions = build_tenant_conditions_fast(tables)
+            return unless tenant_conditions
+
+            select_stmt.where_clause = if select_stmt.where_clause
+              NODE_FACTORIES[:and_expr].call(
+                select_stmt.where_clause,
+                tenant_conditions
+              )
+            else
+              tenant_conditions
+            end
+          end
+
+          def build_tenant_conditions_fast(tables)
+            conditions = []
+
+            # Build conditions for each table
+            tables.each do |table|
+              @tenant_conditions.each do |column, value|
+                column_node = NODE_FACTORIES[:column_ref].call(table[:alias], column.to_s)
+                value_node = NODE_FACTORIES[:integer].call(value) # Assuming integer tenant IDs
+                equality_node = NODE_FACTORIES[:equality].call(column_node, value_node)
+                conditions << equality_node
+              end
+            end
+
+            # Combine all conditions with AND
+            combine_with_and_fast(conditions)
+          end
+
+          def combine_with_and_fast(conditions)
+            return if conditions.empty?
+            return conditions.first if conditions.length == 1
+
+            conditions.reduce do |combined, condition|
+              NODE_FACTORIES[:and_expr].call(combined, condition)
+            end
+          end
+
+          def extract_filterable_tables(select_stmt)
+            tables = []
+            return tables unless select_stmt.from_clause
+
+            select_stmt.from_clause.each do |from_item|
+              extract_tables_from_node_fast(from_item, tables)
+            end
+
+            tables
+          end
+
+          def extract_tables_from_node_fast(node, tables)
+            if node.range_var
+              table_name = node.range_var.relname
+
+              return unless @table_dictionary.has_sharding_key?(table_name)
+
+              alias_name = node.range_var.alias&.aliasname || table_name
+              tables << { name: table_name, alias: alias_name }
+
+            elsif node.join_expr
+              extract_join_tables_fast(node.join_expr, tables)
+            end
+          end
+
+          def extract_join_tables_fast(join_expr, tables)
+            extract_tables_from_node_fast(join_expr.larg, tables) if join_expr.larg
+            extract_tables_from_node_fast(join_expr.rarg, tables) if join_expr.rarg
+          end
+
+          # Fast complexity check - minimal traversal
+          def too_complex?(select_stmt)
+            # Quick checks first (most common cases)
+            return true if select_stmt.with_clause # CTEs
+            return true if select_stmt.window_clause&.any? # Window functions
+
+            # Only check for subqueries if needed
+            select_stmt.target_list&.any? { |target| target.res_target&.val&.sub_link }
+          end
+        end
+      end
+    end
+  end
+end
-- 
GitLab

From cbf5e9574235fd686e9b56453ec6b58f91ea2f5b Mon Sep 17 00:00:00 2001
From: Rutger Wessels
Date: Tue, 2 Sep 2025 15:43:48 +0200
Subject: [PATCH 2/3] Add naive dictionary implementation

---
 lib/gitlab/database/sharding/dictionary.rb   |  21 ++
 lib/gitlab/database/sharding/parse/select.rb | 232 +++++++++++++++++--
 2 files changed, 228 insertions(+), 25 deletions(-)
 create mode 100644 lib/gitlab/database/sharding/dictionary.rb

diff --git a/lib/gitlab/database/sharding/dictionary.rb b/lib/gitlab/database/sharding/dictionary.rb
new file mode 100644
index 00000000000000..087c982807b980
--- /dev/null
+++ b/lib/gitlab/database/sharding/dictionary.rb
@@ -0,0 +1,21 @@
+# frozen_string_literal: true
+
+module Gitlab
+  module Database
+    module Sharding
+      class Dictionary
+        def has_sharding_key?(table_name)
+          sharding_key(table_name).present?
+        end
+
+        def sharding_key(table_name)
+          sharding_keys = Gitlab::Database::Dictionary.entry(table_name).sharding_key || {}
+
+          return if sharding_keys.empty?
+
+          sharding_keys
+        end
+      end
+    end
+  end
+end
diff --git a/lib/gitlab/database/sharding/parse/select.rb b/lib/gitlab/database/sharding/parse/select.rb
index 3b044ee795fb50..6c67951643fc7e 100644
--- a/lib/gitlab/database/sharding/parse/select.rb
+++ b/lib/gitlab/database/sharding/parse/select.rb
@@ -40,18 +40,41 @@ class Select
                 boolop: :AND_EXPR,
                 args: [left_node, right_node]
               ))
+            },
+
+            or_expr: ->(left_node, right_node) {
+              PgQuery::Node.new(bool_expr: PgQuery::BoolExpr.new(
+                boolop: :OR_EXPR,
+                args: [left_node, right_node]
+              ))
+            },
+
+            join_expr: ->(join_type, left_arg, right_arg, quals) {
+              PgQuery::Node.new(join_expr: PgQuery::JoinExpr.new(
+                jointype: join_type,
+                larg: left_arg,
+                rarg: right_arg,
+                quals: quals
+              ))
+            },
+
+            range_var: ->(table_name, alias_name = nil) {
+              PgQuery::Node.new(range_var: PgQuery::RangeVar.new(
+                relname: table_name,
+                alias: alias_name ? PgQuery::Alias.new(aliasname: alias_name) : nil
+              ))
             }
           }.freeze
 
-          def initialize(sql, tenant_conditions = {}, table_dictionary: nil)
+          def initialize(sql, organization_id, table_dictionary: nil)
             @sql = sql
-            @tenant_conditions = tenant_conditions
-            @table_dictionary = table_dictionary || Tenant::TableDictionary.new
+            @organization_id = organization_id
+            @table_dictionary = table_dictionary || ::Gitlab::Database::Sharding::Dictionary.new
             @parsed = nil # Lazy parse
           end
 
           def add_tenant_filter
-            return @sql if @tenant_conditions.empty?
+            return @sql if @organization_id.nil?
 
             # Only parse if we actually need to modify
             @parsed ||= PgQuery.parse(@sql)
@@ -68,7 +91,7 @@ def add_tenant_filter
 
               next if tables_to_filter.empty? || too_complex?(select_stmt)
 
-              add_where_conditions_fast(select_stmt, tables_to_filter)
+              add_organization_filters(select_stmt, tables_to_filter)
               modified = true
             end
 
@@ -77,37 +100,196 @@ def add_tenant_filter
 
           private
 
-          # Fast condition building using pre-built factories
-          def add_where_conditions_fast(select_stmt, tables)
-            # Build all tenant conditions at once
-            tenant_conditions = build_tenant_conditions_fast(tables)
-            return unless tenant_conditions
+          def add_organization_filters(select_stmt, tables)
+            # Group tables by whether they have single or multiple sharding strategies
+            single_strategy_tables = []
+            multi_strategy_tables = []
+
+            tables.each do |table|
+              sharding_info = @table_dictionary.sharding_key(table[:name])
+              next unless sharding_info
+
+              if sharding_info.size == 1
+                single_strategy_tables << { table: table, sharding_info: sharding_info }
+              else
+                multi_strategy_tables << { table: table, sharding_info: sharding_info }
+              end
+            end
+
+            # Handle single strategy tables (existing logic)
+            handle_single_strategy_tables(select_stmt, single_strategy_tables)
+
+            # Handle multi-strategy tables (new OR logic)
+            handle_multi_strategy_tables(select_stmt, multi_strategy_tables)
+          end
+
+          def handle_single_strategy_tables(select_stmt, single_strategy_tables)
+            tables_needing_joins = []
+            direct_conditions = []
+
+            single_strategy_tables.each do |item|
+              table = item[:table]
+              sharding_info = item[:sharding_info]
+
+              column_name, sharding_type = sharding_info.first
+
+              case sharding_type
+              when "organizations"
+                # Direct WHERE condition
+                column_node = NODE_FACTORIES[:column_ref].call(table[:alias], "organization_id")
+                value_node = NODE_FACTORIES[:integer].call(@organization_id)
+                equality_node = NODE_FACTORIES[:equality].call(column_node, value_node)
+                direct_conditions << equality_node
+              when "namespaces", "projects"
+                # Needs JOIN
+                tables_needing_joins << {
+                  table: table,
+                  column: column_name,
+                  join_table: sharding_type
+                }
+              end
+            end
+
+            # Add JOINs for single-strategy tables
+            add_joins_for_organization_filtering(select_stmt, tables_needing_joins)
+
+            # Add direct WHERE conditions for single-strategy tables
+            add_direct_where_conditions(select_stmt, direct_conditions)
+          end
+
+          def handle_multi_strategy_tables(select_stmt, multi_strategy_tables)
+            return if multi_strategy_tables.empty?
+
+            multi_strategy_tables.each do |item|
+              table = item[:table]
+              sharding_info = item[:sharding_info]
+
+              # For multi-strategy tables, we need to add LEFT JOINs for each join-based strategy
+              # and then create an OR condition
+              join_conditions = []
+              direct_conditions = []
+              joins_added = []
+
+              sharding_info.each do |column_name, sharding_type|
+                case sharding_type
+                when "organizations"
+                  # Direct condition: table.organization_id = @organization_id
+                  column_node = NODE_FACTORIES[:column_ref].call(table[:alias], "organization_id")
+                  value_node = NODE_FACTORIES[:integer].call(@organization_id)
+                  equality_node = NODE_FACTORIES[:equality].call(column_node, value_node)
+                  direct_conditions << equality_node
+                when "namespaces", "projects"
+                  # Add LEFT JOIN and create condition
+                  join_alias = "#{sharding_type}_org_join_#{table[:alias]}_#{column_name}"
+
+                  # Add the LEFT JOIN
+                  add_left_join_for_multi_strategy(select_stmt, table, column_name, sharding_type, join_alias)
+                  joins_added << join_alias
+
+                  # Create condition: join_alias.organization_id = @organization_id
+                  org_column = NODE_FACTORIES[:column_ref].call(join_alias, "organization_id")
+                  org_value = NODE_FACTORIES[:integer].call(@organization_id)
+                  join_condition = NODE_FACTORIES[:equality].call(org_column, org_value)
+                  join_conditions << join_condition
+                end
+              end
+
+              # Combine all conditions (direct + join) with OR
+              all_conditions = direct_conditions + join_conditions
+              or_condition = combine_with_or(all_conditions)
+              # Add the OR condition to WHERE clause
+              add_direct_where_conditions(select_stmt, [or_condition]) if or_condition
+            end
+          end
+
+          def add_left_join_for_multi_strategy(select_stmt, table, column_name, join_table_name, join_alias)
+            # Create the join condition: table.column = join_table.id
+            left_column = NODE_FACTORIES[:column_ref].call(table[:alias], column_name)
+            right_column = NODE_FACTORIES[:column_ref].call(join_alias, "id")
+            join_condition = NODE_FACTORIES[:equality].call(left_column, right_column)
+
+            # Create the join table reference
+            join_table_node = NODE_FACTORIES[:range_var].call(join_table_name, join_alias)
+
+            # Create LEFT JOIN (instead of INNER JOIN for multi-strategy)
+            add_join_to_from_clause(select_stmt, join_table_node, join_condition, :JOIN_LEFT)
+          end
+
+          def add_join_to_from_clause(select_stmt, join_table_node, join_condition, join_type = :JOIN_INNER)
+            # If there's already a FROM clause, we need to modify it
+            return unless select_stmt.from_clause && select_stmt.from_clause.any?
+
+            # Get the first (and typically only) FROM item
+            first_from = select_stmt.from_clause.first
+
+            # Create a new JOIN expression
+            join_node = NODE_FACTORIES[:join_expr].call(
+              join_type,
+              first_from,
+              join_table_node,
+              join_condition
+            )
+
+            # Replace the first FROM item with our JOIN
+            select_stmt.from_clause[0] = join_node
+          end
+
+          def add_joins_for_organization_filtering(select_stmt, tables_needing_joins)
+            return if tables_needing_joins.empty?
+
+            tables_needing_joins.each do |join_info|
+              table = join_info[:table]
+              column = join_info[:column]
+              join_table = join_info[:join_table]
+              join_alias = "#{join_table}_org_join_#{table[:alias]}"
+
+              # Create the join condition: table.column = join_table.id
+              left_column = NODE_FACTORIES[:column_ref].call(table[:alias], column)
+              right_column = NODE_FACTORIES[:column_ref].call(join_alias, "id")
+              join_condition = NODE_FACTORIES[:equality].call(left_column, right_column)
+
+              # Add organization_id condition: join_table.organization_id = @organization_id
+              org_column = NODE_FACTORIES[:column_ref].call(join_alias, "organization_id")
+              org_value = NODE_FACTORIES[:integer].call(@organization_id)
+              org_condition = NODE_FACTORIES[:equality].call(org_column, org_value)
+
+              # Combine join conditions
+              combined_condition = NODE_FACTORIES[:and_expr].call(join_condition, org_condition)
+
+              # Create the join table reference
+              join_table_node = NODE_FACTORIES[:range_var].call(join_table, join_alias)
+
+              # Add INNER JOIN for single-strategy tables
+              add_join_to_from_clause(select_stmt, join_table_node, combined_condition, :JOIN_INNER)
+            end
+          end
+
+          def add_direct_where_conditions(select_stmt, conditions)
+            return if conditions.empty?
+
+            # Combine all direct conditions
+            combined_conditions = combine_with_and_fast(conditions)
+            return unless combined_conditions
+
+            # Add to WHERE clause
 
             select_stmt.where_clause = if select_stmt.where_clause
               NODE_FACTORIES[:and_expr].call(
                 select_stmt.where_clause,
-                tenant_conditions
+                combined_conditions
               )
             else
-              tenant_conditions
+              combined_conditions
             end
           end
 
-          def build_tenant_conditions_fast(tables)
-            conditions = []
+          def combine_with_or(conditions)
+            return if conditions.empty?
+            return conditions.first if conditions.length == 1
 
-            # Build conditions for each table
-            tables.each do |table|
-              @tenant_conditions.each do |column, value|
-                column_node = NODE_FACTORIES[:column_ref].call(table[:alias], column.to_s)
-                value_node = NODE_FACTORIES[:integer].call(value) # Assuming integer tenant IDs
-                equality_node = NODE_FACTORIES[:equality].call(column_node, value_node)
-                conditions << equality_node
-              end
+            conditions.reduce do |combined, condition|
+              NODE_FACTORIES[:or_expr].call(combined, condition)
             end
-
-            # Combine all conditions with AND
-            combine_with_and_fast(conditions)
           end
 
           def combine_with_and_fast(conditions)
-- 
GitLab

From 91919ade3ba7f0f0dc616e95332394f02e214b6a Mon Sep 17 00:00:00 2001
From: Rutger Wessels
Date: Tue, 2 Sep 2025 17:45:42 +0200
Subject: [PATCH 3/3] Patch PG::Connection: add organization filters

---
 .../active_record_organization_filter.rb   | 65 +++++++++++++++++++
 lib/gitlab/database/sharding/dictionary.rb |  2 +-
 2 files changed, 66 insertions(+), 1 deletion(-)
 create mode 100644 config/initializers/active_record_organization_filter.rb

diff --git a/config/initializers/active_record_organization_filter.rb b/config/initializers/active_record_organization_filter.rb
new file mode 100644
index 00000000000000..51d1331c421d0f
--- /dev/null
+++ b/config/initializers/active_record_organization_filter.rb
@@ -0,0 +1,65 @@
+# frozen_string_literal: true
+
+# config/initializers/active_record_organization_filter.rb
+#
+# Adds organization filters for the current organization to SELECT statements
+# sent through PG::Connection, using Gitlab::Database::Sharding::Parse::Select.
+#
+# NOTE(review): only #exec_params, #exec and #async_exec are intercepted; SQL
+# sent through other PG::Connection methods (e.g. #prepare) is not rewritten.
+
+module PGConnectionPatch
+  # Appended to every statement this patch has processed, so that it is not
+  # processed a second time.
+  TENANT_FILTERED_MARKER = "/*TENANT_FILTERED*/"
+
+  def exec_params(sql, *args, &block)
+    super(add_global_filter(sql), *args, &block)
+  end
+
+  def exec(sql, *args, &block)
+    super(add_global_filter(sql), *args, &block)
+  end
+
+  def async_exec(sql, *args, &block)
+    super(add_global_filter(sql), *args, &block)
+  end
+
+  private
+
+  # Returns +sql+ with organization filters added and TENANT_FILTERED_MARKER
+  # appended when an organization is assigned and +sql+ is a String starting
+  # with SELECT. Anything else (including statements that start with a comment
+  # or WITH) is returned unchanged.
+  def add_global_filter(sql)
+    return sql unless Current.organization_assigned
+    return sql unless sql.is_a?(String) && sql.match?(/\A\s*SELECT/i)
+    # Only trust the marker at the end of the statement, where it is appended
+    # below. A marker anywhere else (for example inside a quoted literal that
+    # carries user input) must not switch filtering off.
+    return sql if sql.rstrip.end_with?(TENANT_FILTERED_MARKER)
+
+    # rubocop:disable Gitlab/AvoidCurrentOrganization -- We actually need it here.
+    parser = Gitlab::Database::Sharding::Parse::Select.new(sql, Current.organization.id)
+    # rubocop:enable Gitlab/AvoidCurrentOrganization
+    filtered_sql = parser.add_tenant_filter
+    "#{filtered_sql} #{TENANT_FILTERED_MARKER}"
+  rescue StandardError => e
+    # Fail open (the statement runs unfiltered), but leave a trace instead of
+    # failing silently.
+    Rails.logger.warn("Organization filter not applied: #{e.class}: #{e.message}")
+    sql
+  end
+end
+
+# Patch PG::Connection class globally
+module PGConnectionGlobalPatch
+  include PGConnectionPatch
+end
+
+Rails.application.config.after_initialize do
+  # Prepending to the class changes method lookup for every PG::Connection
+  # instance, including connections opened before this block runs, so no
+  # connection has to be checked out here and no singleton class needs patching.
+  PG::Connection.prepend(PGConnectionGlobalPatch)
+end
diff --git a/lib/gitlab/database/sharding/dictionary.rb b/lib/gitlab/database/sharding/dictionary.rb
index 087c982807b980..da2c3b64ab0a41 100644
--- a/lib/gitlab/database/sharding/dictionary.rb
+++ b/lib/gitlab/database/sharding/dictionary.rb
@@ -9,7 +9,7 @@ def has_sharding_key?(table_name)
         end
 
         def sharding_key(table_name)
-          sharding_keys = Gitlab::Database::Dictionary.entry(table_name).sharding_key || {}
+          sharding_keys = Gitlab::Database::Dictionary.entry(table_name)&.sharding_key || {}
 
           return if sharding_keys.empty?
 
-- 
GitLab