diff --git a/config/initializers/active_record_organization_filter.rb b/config/initializers/active_record_organization_filter.rb
new file mode 100644
index 0000000000000000000000000000000000000000..51d1331c421d0fa3b60e341dda80f89aab18114a
--- /dev/null
+++ b/config/initializers/active_record_organization_filter.rb
@@ -0,0 +1,63 @@
+# frozen_string_literal: true
+
+# config/initializers/active_record_organization_filter.rb
+#
+# Rewrites SELECT statements sent through PG::Connection so that tables with a
+# sharding key only return rows belonging to Current.organization. The rewrite
+# itself is done by Gitlab::Database::Sharding::Parse::Select.
+
+module PGConnectionPatch
+  # Extra positional arguments and blocks are forwarded untouched so these
+  # overrides do not narrow the signatures of PG::Connection's own methods.
+  def exec(sql, *args, &block)
+    super(add_global_filter(sql), *args, &block)
+  end
+
+  def async_exec(sql, *args, &block)
+    super(add_global_filter(sql), *args, &block)
+  end
+
+  def exec_params(sql, *args, &block)
+    super(add_global_filter(sql), *args, &block)
+  end
+
+  def async_exec_params(sql, *args, &block)
+    super(add_global_filter(sql), *args, &block)
+  end
+
+  private
+
+  # Returns +sql+ with the organization filter applied and a
+  # /*TENANT_FILTERED*/ marker appended, so the same statement is never
+  # rewritten twice. Returns +sql+ unchanged when no organization is assigned,
+  # the statement is not a SELECT, or it already carries the marker.
+  def add_global_filter(sql)
+    return sql unless Current.organization_assigned
+    return sql unless sql.is_a?(String) && sql.match?(/\A\s*SELECT/i)
+    return sql if sql.include?("/*TENANT_FILTERED*/")
+
+    # rubocop:disable Gitlab/AvoidCurrentOrganization -- We actually need it here.
+    parser = Gitlab::Database::Sharding::Parse::Select.new(sql, Current.organization.id)
+    # rubocop:enable Gitlab/AvoidCurrentOrganization
+    filtered_sql = parser.add_tenant_filter
+    "#{filtered_sql} /*TENANT_FILTERED*/"
+  rescue StandardError => e
+    # Best effort: the statement runs unfiltered, but the failure is reported
+    # instead of being silently swallowed.
+    Gitlab::ErrorTracking.track_exception(e)
+    sql
+  end
+end
+
+# Patch PG::Connection class globally
+module PGConnectionGlobalPatch
+  include PGConnectionPatch
+end
+
+Rails.application.config.after_initialize do
+  # Prepending to the class changes method lookup for every PG::Connection,
+  # including connections opened before this point, so no per-connection
+  # patching is needed. It also avoids checking out a database connection
+  # while the application boots.
+  PG::Connection.prepend(PGConnectionGlobalPatch)
+end
diff --git a/lib/gitlab/database/sharding/dictionary.rb b/lib/gitlab/database/sharding/dictionary.rb
new file mode 100644
index 0000000000000000000000000000000000000000..da2c3b64ab0a41001f25bc7132ed77b288ccbbfb
--- /dev/null
+++ b/lib/gitlab/database/sharding/dictionary.rb
@@ -0,0 +1,23 @@
+# frozen_string_literal: true
+
+module Gitlab
+  module Database
+    module Sharding
+      # Sharding key lookups backed by Gitlab::Database::Dictionary entries.
+      class Dictionary
+        # @param table_name [String]
+        # @return [Boolean] whether the table has a non-empty sharding key
+        def has_sharding_key?(table_name)
+          sharding_key(table_name).present?
+        end
+
+        # @param table_name [String]
+        # @return [Hash, nil] the table's sharding key as read from its dictionary
+        #   entry, or nil when the table has no entry or an empty/missing key
+        def sharding_key(table_name)
+          Gitlab::Database::Dictionary.entry(table_name)&.sharding_key.presence
+        end
+      end
+    end
+  end
+end
diff --git a/lib/gitlab/database/sharding/parse/select.rb b/lib/gitlab/database/sharding/parse/select.rb
new file mode 100644
index 0000000000000000000000000000000000000000..6c67951643fc7e02849f15c6af1de50047147649
--- /dev/null
+++ b/lib/gitlab/database/sharding/parse/select.rb
@@ -0,0 +1,342 @@
+# frozen_string_literal: true
+
+module Gitlab
+  module Database
+    module Sharding
+      module Parse
+        # Rewrites SELECT statements so that every table with a sharding key is
+        # restricted to a single organization.
+        #
+        # Sharding keys come from the table dictionary as `column => sharding table`
+        # pairs:
+        # - "organizations": compare `table.column` with the organization id.
+        # - "namespaces" / "projects": join the referenced table on
+        #   `table.column = joined.id` and compare `joined.organization_id`.
+        #
+        # A table with one sharding key gets a WHERE condition or an INNER JOIN; a
+        # table with several gets LEFT JOINs and an OR of all its conditions.
+        #
+        # NOTE(review): only tables named directly in the FROM clause (including its
+        # JOINs) of top-level SELECT statements are filtered. Sub-selects in FROM or
+        # WHERE and set operations (UNION etc.) are left unchanged, and statements
+        # with CTEs, WINDOW clauses or sub-selects in the target list are skipped.
+        class Select
+          NODE_FACTORIES = {
+            string: ->(str) {
+              PgQuery::Node.new(string: PgQuery::String.new(sval: str))
+            },
+
+            integer: ->(val) {
+              PgQuery::Node.new(a_const: PgQuery::A_Const.new(
+                ival: PgQuery::Integer.new(ival: val),
+                isnull: false
+              ))
+            },
+
+            column_ref: ->(table_alias, column) {
+              PgQuery::Node.new(column_ref: PgQuery::ColumnRef.new(
+                fields: [
+                  PgQuery::Node.new(string: PgQuery::String.new(sval: table_alias)),
+                  PgQuery::Node.new(string: PgQuery::String.new(sval: column))
+                ]
+              ))
+            },
+
+            equality: ->(left_node, right_node) {
+              PgQuery::Node.new(a_expr: PgQuery::A_Expr.new(
+                kind: :AEXPR_OP,
+                name: [PgQuery::Node.new(string: PgQuery::String.new(sval: "="))],
+                lexpr: left_node,
+                rexpr: right_node
+              ))
+            },
+
+            and_expr: ->(left_node, right_node) {
+              PgQuery::Node.new(bool_expr: PgQuery::BoolExpr.new(
+                boolop: :AND_EXPR,
+                args: [left_node, right_node]
+              ))
+            },
+
+            or_expr: ->(left_node, right_node) {
+              PgQuery::Node.new(bool_expr: PgQuery::BoolExpr.new(
+                boolop: :OR_EXPR,
+                args: [left_node, right_node]
+              ))
+            },
+
+            join_expr: ->(join_type, left_arg, right_arg, quals) {
+              PgQuery::Node.new(join_expr: PgQuery::JoinExpr.new(
+                jointype: join_type,
+                larg: left_arg,
+                rarg: right_arg,
+                quals: quals
+              ))
+            },
+
+            # `inh: true` / `relpersistence: "p"` match what the parser produces for
+            # a plain table reference. With the protobuf default (`inh: false`) the
+            # deparser emits `ONLY <table>`, which excludes child tables/partitions.
+            range_var: ->(table_name, alias_name = nil) {
+              PgQuery::Node.new(range_var: PgQuery::RangeVar.new(
+                relname: table_name,
+                inh: true,
+                relpersistence: "p",
+                alias: alias_name ? PgQuery::Alias.new(aliasname: alias_name) : nil
+              ))
+            }
+          }.freeze
+
+          # @param sql [String] the statement(s) to rewrite
+          # @param organization_id [Integer, nil] organization to filter on; nil disables filtering
+          # @param table_dictionary [#sharding_key, #has_sharding_key?] sharding key lookup,
+          #   defaults to Gitlab::Database::Sharding::Dictionary
+          def initialize(sql, organization_id, table_dictionary: nil)
+            @sql = sql
+            @organization_id = organization_id
+            @table_dictionary = table_dictionary || ::Gitlab::Database::Sharding::Dictionary.new
+          end
+
+          # Returns the SQL with organization filters added, or the original SQL when
+          # no statement needed (or could be given) a filter.
+          #
+          # @return [String]
+          # @raise [PgQuery::ParseError] when the SQL cannot be parsed
+          def add_tenant_filter
+            return @sql if @organization_id.nil?
+
+            # Memoized: the parse tree is rewritten in place, so rewriting it again
+            # would add every JOIN and condition a second time.
+            @filtered_sql ||= build_filtered_sql
+          end
+
+          private
+
+          # Parses @sql, rewrites every eligible top-level SELECT in place and
+          # deparses the result; returns @sql itself when nothing was rewritten.
+          def build_filtered_sql
+            parsed = PgQuery.parse(@sql)
+            modified = false
+
+            parsed.tree.stmts.each do |raw_stmt|
+              select_stmt = raw_stmt.stmt.select_stmt
+              next unless select_stmt
+
+              tables_to_filter = extract_filterable_tables(select_stmt)
+              next if tables_to_filter.empty? || too_complex?(select_stmt)
+
+              add_organization_filters(select_stmt, tables_to_filter)
+              modified = true
+            end
+
+            modified ? parsed.deparse : @sql
+          end
+
+          # Splits the tables by number of sharding keys and applies the matching strategy.
+          def add_organization_filters(select_stmt, tables)
+            items = tables.filter_map do |table|
+              sharding_info = @table_dictionary.sharding_key(table[:name])
+              { table: table, sharding_info: sharding_info } if sharding_info
+            end
+
+            single_strategy_tables, multi_strategy_tables = items.partition { |item| item[:sharding_info].size == 1 }
+
+            handle_single_strategy_tables(select_stmt, single_strategy_tables)
+            handle_multi_strategy_tables(select_stmt, multi_strategy_tables)
+          end
+
+          # One sharding key per table: "organizations" keys become WHERE conditions,
+          # "namespaces"/"projects" keys become INNER JOINs carrying the condition.
+          def handle_single_strategy_tables(select_stmt, single_strategy_tables)
+            tables_needing_joins = []
+            direct_conditions = []
+
+            single_strategy_tables.each do |item|
+              table = item[:table]
+              column_name, sharding_type = item[:sharding_info].first
+
+              case sharding_type
+              when "organizations"
+                # Compare the sharding key column itself, whatever its name is.
+                direct_conditions << organization_condition(table[:alias], column_name)
+              when "namespaces", "projects"
+                tables_needing_joins << { table: table, column: column_name, join_table: sharding_type }
+              end
+            end
+
+            add_joins_for_organization_filtering(select_stmt, tables_needing_joins)
+            add_direct_where_conditions(select_stmt, direct_conditions)
+          end
+
+          # Several sharding keys per table: each key contributes one condition
+          # ("organizations" directly, "namespaces"/"projects" through a LEFT JOIN)
+          # and a row is kept when any of them matches.
+          def handle_multi_strategy_tables(select_stmt, multi_strategy_tables)
+            multi_strategy_tables.each do |item|
+              table = item[:table]
+              direct_conditions = []
+              join_conditions = []
+
+              item[:sharding_info].each do |column_name, sharding_type|
+                case sharding_type
+                when "organizations"
+                  direct_conditions << organization_condition(table[:alias], column_name)
+                when "namespaces", "projects"
+                  join_alias = "#{sharding_type}_org_join_#{table[:alias]}_#{column_name}"
+                  add_left_join_for_multi_strategy(select_stmt, table, column_name, sharding_type, join_alias)
+                  join_conditions << organization_condition(join_alias, "organization_id")
+                end
+              end
+
+              or_condition = combine_with_or(direct_conditions + join_conditions)
+              add_direct_where_conditions(select_stmt, [or_condition]) if or_condition
+            end
+          end
+
+          # LEFT JOINs `join_table_name AS join_alias ON table.column_name = join_alias.id`.
+          def add_left_join_for_multi_strategy(select_stmt, table, column_name, join_table_name, join_alias)
+            join_condition = NODE_FACTORIES[:equality].call(
+              NODE_FACTORIES[:column_ref].call(table[:alias], column_name),
+              NODE_FACTORIES[:column_ref].call(join_alias, "id")
+            )
+            join_table_node = NODE_FACTORIES[:range_var].call(join_table_name, join_alias)
+
+            add_join_to_from_clause(
+              select_stmt, join_table_node, join_condition, :JOIN_LEFT, table_alias: table[:alias]
+            )
+          end
+
+          # Replaces the FROM item that defines `table_alias` (the first FROM item when
+          # no alias is given or none matches) with `<item> JOIN join_table_node ON join_condition`.
+          # Joining onto the defining item keeps the ON clause valid for comma-separated
+          # FROM lists, where a JOIN's ON clause may only reference the JOIN's own operands.
+          def add_join_to_from_clause(
+            select_stmt, join_table_node, join_condition, join_type = :JOIN_INNER, table_alias: nil
+          )
+            from_clause = select_stmt.from_clause
+            return unless from_clause&.any?
+
+            index = table_alias && from_clause.find_index { |item| from_item_defines_alias?(item, table_alias) }
+            index ||= 0
+
+            from_clause[index] = NODE_FACTORIES[:join_expr].call(
+              join_type,
+              from_clause[index],
+              join_table_node,
+              join_condition
+            )
+          end
+
+          # Whether the FROM item `node` is, or contains through JOINs, a table
+          # reference aliased (or, without an alias, named) `table_alias`.
+          def from_item_defines_alias?(node, table_alias)
+            if node.range_var
+              (node.range_var.alias&.aliasname || node.range_var.relname) == table_alias
+            elsif node.join_expr
+              [node.join_expr.larg, node.join_expr.rarg].compact.any? do |arg|
+                from_item_defines_alias?(arg, table_alias)
+              end
+            else
+              false
+            end
+          end
+
+          # INNER JOINs the referenced table with the organization condition in the ON clause:
+          #   JOIN <join_table> <alias> ON <table>.<column> = <alias>.id AND <alias>.organization_id = <id>
+          def add_joins_for_organization_filtering(select_stmt, tables_needing_joins)
+            tables_needing_joins.each do |join_info|
+              table = join_info[:table]
+              join_table = join_info[:join_table]
+              join_alias = "#{join_table}_org_join_#{table[:alias]}"
+
+              join_condition = NODE_FACTORIES[:equality].call(
+                NODE_FACTORIES[:column_ref].call(table[:alias], join_info[:column]),
+                NODE_FACTORIES[:column_ref].call(join_alias, "id")
+              )
+              combined_condition = NODE_FACTORIES[:and_expr].call(
+                join_condition,
+                organization_condition(join_alias, "organization_id")
+              )
+              join_table_node = NODE_FACTORIES[:range_var].call(join_table, join_alias)
+
+              add_join_to_from_clause(
+                select_stmt, join_table_node, combined_condition, :JOIN_INNER, table_alias: table[:alias]
+              )
+            end
+          end
+
+          # ANDs `conditions` onto the WHERE clause.
+          # NOTE(review): a condition on a table from the nullable side of an outer JOIN
+          # drops that JOIN's NULL-extended rows, making it behave like an INNER JOIN.
+          def add_direct_where_conditions(select_stmt, conditions)
+            combined_conditions = combine_with_and(conditions)
+            return unless combined_conditions
+
+            existing_conditions = select_stmt.where_clause
+            select_stmt.where_clause = if existing_conditions
+                                         NODE_FACTORIES[:and_expr].call(existing_conditions, combined_conditions)
+                                       else
+                                         combined_conditions
+                                       end
+          end
+
+          # `<table_alias>.<column> = <organization id>`
+          def organization_condition(table_alias, column)
+            NODE_FACTORIES[:equality].call(
+              NODE_FACTORIES[:column_ref].call(table_alias, column.to_s),
+              NODE_FACTORIES[:integer].call(@organization_id)
+            )
+          end
+
+          # ORs `conditions` together; nil when there are none.
+          def combine_with_or(conditions)
+            combine(conditions, :or_expr)
+          end
+
+          # ANDs `conditions` together; nil when there are none.
+          def combine_with_and(conditions)
+            combine(conditions, :and_expr)
+          end
+
+          # Left-folds `conditions` with the binary NODE_FACTORIES entry `factory`.
+          def combine(conditions, factory)
+            conditions.reduce { |combined, condition| NODE_FACTORIES[factory].call(combined, condition) }
+          end
+
+          # Collects `{ name:, alias: }` for each table with a sharding key that is
+          # referenced in the FROM clause, directly or inside its JOINs.
+          def extract_filterable_tables(select_stmt)
+            tables = []
+            select_stmt.from_clause&.each { |from_item| extract_tables_from_node_fast(from_item, tables) }
+            tables
+          end
+
+          def extract_tables_from_node_fast(node, tables)
+            if node.range_var
+              table_name = node.range_var.relname
+              return unless @table_dictionary.has_sharding_key?(table_name)
+
+              tables << { name: table_name, alias: node.range_var.alias&.aliasname || table_name }
+            elsif node.join_expr
+              extract_join_tables_fast(node.join_expr, tables)
+            end
+          end
+
+          def extract_join_tables_fast(join_expr, tables)
+            extract_tables_from_node_fast(join_expr.larg, tables) if join_expr.larg
+            extract_tables_from_node_fast(join_expr.rarg, tables) if join_expr.rarg
+          end
+
+          # Statements the rewriter does not handle: CTEs, WINDOW clauses and
+          # sub-selects in the target list.
+          def too_complex?(select_stmt)
+            return true if select_stmt.with_clause
+            return true if select_stmt.window_clause&.any?
+
+            select_stmt.target_list&.any? { |target| target.res_target&.val&.sub_link } || false
+          end
+        end
+      end
+    end
+  end
+end