[go: up one dir, main page]

Fix AIP parser breaking change

Change-Id: Ic53d3f8e89d18fd23a42591c7a91fb204d78ae3d
Reviewed-on: https://chromium-review.googlesource.com/c/infra/infra/+/7265145
Commit-Queue: Vaghinak Vardanyan <vaghinak@google.com>
Auto-Submit: Konrad Polarczyk <polarczyk@google.com>
Commit-Queue: Piotr Grabski-Gradziński <piotrgg@google.com>
Reviewed-by: Piotr Grabski-Gradziński <piotrgg@google.com>
Cr-Commit-Position: refs/heads/main@{#78099}
diff --git a/go/src/infra/fleetconsole/internal/consoleserver/list_resource_requests_test.go b/go/src/infra/fleetconsole/internal/consoleserver/list_resource_requests_test.go
index e31809f..29e2652 100644
--- a/go/src/infra/fleetconsole/internal/consoleserver/list_resource_requests_test.go
+++ b/go/src/infra/fleetconsole/internal/consoleserver/list_resource_requests_test.go
@@ -10,7 +10,6 @@
 
 	"cloud.google.com/go/bigquery"
 
-	"go.chromium.org/luci/common/data/aip160"
 	"go.chromium.org/luci/common/testing/truth/assert"
 	"go.chromium.org/luci/common/testing/truth/should"
 
@@ -40,7 +39,7 @@
 LIMIT 11
 OFFSET 0;`),
 		))
-		assert.Loosely(t, query.Parameters[0].Value.(*aip160.Value).Value, should.Equal("RR-001"))
-		assert.Loosely(t, query.Parameters[1].Value.(*aip160.Value).Value, should.Equal("RR-002"))
+		assert.Loosely(t, query.Parameters[0].Value, should.Equal("RR-001"))
+		assert.Loosely(t, query.Parameters[1].Value, should.Equal("RR-002"))
 	})
 }
diff --git a/go/src/infra/fleetconsole/internal/database/android_devicedb/devicecleanup.go b/go/src/infra/fleetconsole/internal/database/android_devicedb/devicecleanup.go
index f981d83..96f04a7 100644
--- a/go/src/infra/fleetconsole/internal/database/android_devicedb/devicecleanup.go
+++ b/go/src/infra/fleetconsole/internal/database/android_devicedb/devicecleanup.go
@@ -47,21 +47,21 @@
 		switch filterType := lf.GetFilterType().(type) {
 		case *configpb.LabFilter_ProcessAll:
 			if filterType.ProcessAll {
-				conditions = append(conditions, fmt.Sprintf("%s = %s", LabName.Name, qb.Bind(LabName.Name, lf.GetName())))
+				conditions = append(conditions, fmt.Sprintf("%s = %s", LabName.Name, queryutils.Bind(qb, LabName.Name, lf.GetName())))
 			}
 		case *configpb.LabFilter_AllowedHostGroups:
 			groups := filterType.AllowedHostGroups.GetHostGroups()
 			if len(groups) > 0 {
 				// Devices in this lab, but NOT in allowed host groups, should be removed.
-				params := utils.Map(groups, func(h string) string { return qb.Bind(HostGroup.Name, h) })
-				conditions = append(conditions, fmt.Sprintf("(%s = %s AND %s && ARRAY[%s])", LabName.Name, qb.Bind(LabName.Name, lf.GetName()), HostGroup.Name, strings.Join(params, ", ")))
+				params := utils.Map(groups, func(h string) string { return queryutils.Bind(qb, HostGroup.Name, h) })
+				conditions = append(conditions, fmt.Sprintf("(%s = %s AND %s && ARRAY[%s])", LabName.Name, queryutils.Bind(qb, LabName.Name, lf.GetName()), HostGroup.Name, strings.Join(params, ", ")))
 			}
 		case *configpb.LabFilter_DeniedHostGroups:
 			groups := filterType.DeniedHostGroups.GetHostGroups()
 			if len(groups) > 0 {
 				// Devices in this lab AND in denied host groups, should be removed.
-				params := utils.Map(groups, func(h string) string { return qb.Bind(HostGroup.Name, h) })
-				conditions = append(conditions, fmt.Sprintf("(%s = %s AND NOT %s && ARRAY[%s])", LabName.Name, qb.Bind(LabName.Name, lf.GetName()), HostGroup.Name, strings.Join(params, ", ")))
+				params := utils.Map(groups, func(h string) string { return queryutils.Bind(qb, HostGroup.Name, h) })
+				conditions = append(conditions, fmt.Sprintf("(%s = %s AND NOT %s && ARRAY[%s])", LabName.Name, queryutils.Bind(qb, LabName.Name, lf.GetName()), HostGroup.Name, strings.Join(params, ", ")))
 			}
 		}
 	}
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator.go b/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator.go
index 04f9994..0a0ab71 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator.go
@@ -11,6 +11,8 @@
 
 	"go.chromium.org/luci/common/data/aip160"
 	"go.chromium.org/luci/common/errors"
+
+	"go.chromium.org/infra/fleetconsole/internal/utils"
 )
 
 // compositeArgInfo is the info used when the `arg` has composite expression.
@@ -136,8 +138,8 @@
 			return "", fmt.Errorf("missing field for json column %q", column.ExternalName)
 		}
 
-		fullPath := column.jsonFullPath(fieldsToStrings(simple.Restriction.Comparable.Member.Fields)...)
-		fieldToCheck := fmt.Sprintf("(%s #> %s)", column.Name, q.Bind(column.Name, fullPath))
+		fullPath := column.jsonFullPath(extractValuesFromFields(simple.Restriction.Comparable.Member.Fields)...)
+		fieldToCheck := fmt.Sprintf("(%s #> %s)", column.Name, Bind(q, column.Name, fullPath))
 		return fmt.Sprintf("(%s IS NULL OR %s IN ('[]'::jsonb, 'null'::jsonb))", fieldToCheck, fieldToCheck), nil
 	case ColumnTypeArray:
 		return fmt.Sprintf("(%s IS NULL OR cardinality(%s) = 0)", column.Name, column.Name), nil
@@ -188,7 +190,7 @@
 		return q.expressionQuery(restriction.Arg.Composite, &compositeArgInfo{
 			comparator: restriction.Comparator,
 			column:     column,
-			fields:     fieldsToStrings(restriction.Comparable.Member.Fields),
+			fields:     extractValuesFromFields(restriction.Comparable.Member.Fields),
 		}, comparableOverride)
 	}
 
@@ -213,7 +215,7 @@
 	// as labels have multiple values and we basically check the existence of the specified value in the array.
 	// Currently leaving it as `=` to comply with the frontend.
 	if restriction.Comparator == "=" {
-		value, err := q.jsonArrayHasArgValue(restriction.Arg, column, fieldsToStrings(restriction.Comparable.Member.Fields))
+		value, err := q.jsonArrayHasArgValue(restriction.Arg, column, extractValuesFromFields(restriction.Comparable.Member.Fields))
 		if err != nil {
 			return "", errors.Fmt("argument for field %q: %w", column.ExternalName, err)
 		}
@@ -314,7 +316,7 @@
 	if restriction.Comparator == "" {
 		if len(restriction.Comparable.Member.Fields) > 0 {
 			value := restriction.Comparable.Member.Value.Value
-			fields := strings.Join(fieldsToStrings(restriction.Comparable.Member.Fields), ".")
+			fields := strings.Join(extractValuesFromFields(restriction.Comparable.Member.Fields), ".")
 			return fmt.Errorf("fields are not allowed without an operator, try wrapping %s.%s in double quotes: \"%s.%s\"", value, fields, value, fields)
 		}
 		return fmt.Errorf("global properties are not supported without a comparator. Got: %q", restriction.Comparable.Member.Value.Value)
@@ -406,7 +408,7 @@
 		return "", errors.New("fields not implemented yet")
 	}
 
-	return q.Bind(columnName, comparable.Member.Value), nil
+	return Bind(q, columnName, comparable.Member.Value.Value), nil
 }
 
 // Checks whether value exist in the array specified in the json path
@@ -432,17 +434,13 @@
 	fullPath := column.jsonFullPath(fields...)
 	params := make([]string, len(fullPath))
 	for i, field := range fullPath {
-		params[i] = q.Bind("json_path_variable_"+strconv.Itoa(i)+":"+column.Name, field)
+		params[i] = Bind(q, "json_path_variable_"+strconv.Itoa(i)+":"+column.Name, field)
 	}
 
-	value := q.Bind(column.Name, comparable.Member.Value)
+	value := Bind(q, column.Name, comparable.Member.Value.Value)
 	return fmt.Sprintf("%s ? %s", strings.Join(params, " -> "), value), nil
 }
 
-func fieldsToStrings(fields []*aip160.Value) []string {
-	strings := make([]string, len(fields))
-	for i, field := range fields {
-		strings[i] = field.Value
-	}
-	return strings
+func extractValuesFromFields(fields []*aip160.Value) []string {
+	return utils.Map(fields, func(x *aip160.Value) string { return x.Value })
 }
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator_test.go b/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator_test.go
index 74069de..73e0f94 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator_test.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator_test.go
@@ -7,7 +7,6 @@
 import (
 	"testing"
 
-	"go.chromium.org/luci/common/data/aip160"
 	"go.chromium.org/luci/common/testing/ftt"
 	"go.chromium.org/luci/common/testing/truth/assert"
 	"go.chromium.org/luci/common/testing/truth/should"
@@ -39,7 +38,7 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("dut_state = available", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "dut_state", Value: &aip160.Value{Value: "available"}},
+					{Name: "dut_state", Value: "available"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (dut_state = $1)"))
 			})
@@ -47,7 +46,7 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("dut_state != available", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "dut_state", Value: &aip160.Value{Value: "available"}},
+					{Name: "dut_state", Value: "available"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (dut_state <> $1)"))
 			})
@@ -55,7 +54,7 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("test_date > 2020-01-01", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "test_date", Value: &aip160.Value{Value: "2020-01-01"}},
+					{Name: "test_date", Value: "2020-01-01"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (test_date > $1)"))
 			})
@@ -63,7 +62,7 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("test_date < 2020-01-01", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "test_date", Value: &aip160.Value{Value: "2020-01-01"}},
+					{Name: "test_date", Value: "2020-01-01"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (test_date < $1)"))
 			})
@@ -71,7 +70,7 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("test_date >= 2020-01-01", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "test_date", Value: &aip160.Value{Value: "2020-01-01"}},
+					{Name: "test_date", Value: "2020-01-01"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (test_date >= $1)"))
 			})
@@ -79,7 +78,7 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("test_date <= 2020-01-01", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "test_date", Value: &aip160.Value{Value: "2020-01-01"}},
+					{Name: "test_date", Value: "2020-01-01"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (test_date <= $1)"))
 			})
@@ -87,7 +86,7 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("host_group = \"acs:camerax\"", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "host_group", Value: &aip160.Value{Value: "acs:camerax", Quoted: true}},
+					{Name: "host_group", Value: "acs:camerax"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (host_group @> ARRAY[$1])"))
 			})
@@ -99,7 +98,7 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("dut_state=(something)", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "dut_state", Value: &aip160.Value{Value: "something"}},
+					{Name: "dut_state", Value: "something"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (dut_state = $1)"))
 			})
@@ -124,7 +123,7 @@
 					{Name: "json_path_variable_2:labels", Value: string("values")},
 					{
 						Name:  "labels",
-						Value: &aip160.Value{Value: "OS_TYPE_CROS"},
+						Value: "OS_TYPE_CROS",
 					},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (labels -> $1 -> $2 -> $3 ? $4)"))
@@ -139,17 +138,17 @@
 					{Name: "json_path_variable_2:labels", Value: string("values")},
 					{
 						Name:  "labels",
-						Value: &aip160.Value{Value: "OS_TYPE_CROS"},
+						Value: "OS_TYPE_CROS",
 					},
 					{Name: "json_path_variable_0:labels", Value: "labels"},
 					{Name: "json_path_variable_1:labels", Value: "label-os_type"},
 					{Name: "json_path_variable_2:labels", Value: string("values")},
 					{
 						Name:  "labels",
-						Value: &aip160.Value{Value: "OS_TYPE_LABSTATION"},
+						Value: "OS_TYPE_LABSTATION",
 					},
-					{Name: "dut_state", Value: &aip160.Value{Value: "DEVICE_STATE_LEASED"}},
-					{Name: "dut_state", Value: &aip160.Value{Value: "DEVICE_STATE_AVAILABLE"}},
+					{Name: "dut_state", Value: "DEVICE_STATE_LEASED"},
+					{Name: "dut_state", Value: "DEVICE_STATE_AVAILABLE"},
 					// "labels", "label-os_type", "values", "OS_TYPE_CROS", "labels", "label-os_type", "values", "OS_TYPE_LABSTATION", "DEVICE_STATE_LEASED", "DEVICE_STATE_AVAILABLE",
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (((labels -> $1 -> $2 -> $3 ? $4) AND (labels -> $5 -> $6 -> $7 ? $8)) AND ((dut_state = $9) AND (dut_state = $10)))"))
@@ -159,9 +158,9 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("dut_state = (a AND b OR c)", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "dut_state", Value: &aip160.Value{Value: "a"}},
-					{Name: "dut_state", Value: &aip160.Value{Value: "b"}},
-					{Name: "dut_state", Value: &aip160.Value{Value: "c"}},
+					{Name: "dut_state", Value: "a"},
+					{Name: "dut_state", Value: "b"},
+					{Name: "dut_state", Value: "c"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE ((dut_state = $1) AND ((dut_state = $2) OR (dut_state = $3)))"))
 			})
@@ -170,9 +169,9 @@
 				q, err := NewQueryBuilder(table).WithWhereClause("host_group = (a AND b OR c)", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "host_group", Value: &aip160.Value{Value: "a"}},
-					{Name: "host_group", Value: &aip160.Value{Value: "b"}},
-					{Name: "host_group", Value: &aip160.Value{Value: "c"}},
+					{Name: "host_group", Value: "a"},
+					{Name: "host_group", Value: "b"},
+					{Name: "host_group", Value: "c"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE ((host_group @> ARRAY[$1]) AND ((host_group @> ARRAY[$2]) OR (host_group @> ARRAY[$3])))"))
 			})
@@ -196,8 +195,8 @@
 					q, err := NewQueryBuilder(table).WithWhereClause("(NOT dut_state OR dut_state = (DEVICE_STATE_LEASED OR DEVICE_STATE_AVAILABLE))", nil)
 					assert.NoErr(t, err)
 					assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-						{Name: "dut_state", Value: &aip160.Value{Value: "DEVICE_STATE_LEASED"}},
-						{Name: "dut_state", Value: &aip160.Value{Value: "DEVICE_STATE_AVAILABLE"}},
+						{Name: "dut_state", Value: "DEVICE_STATE_LEASED"},
+						{Name: "dut_state", Value: "DEVICE_STATE_AVAILABLE"},
 					}))
 					assert.Loosely(t, q.whereClause, should.Equal("WHERE ((dut_state = '' OR dut_state IS NULL) OR ((dut_state = $1) OR (dut_state = $2)))"))
 				})
@@ -207,7 +206,7 @@
 				q, err := NewQueryBuilder(table).SetSqlLangType(BigQueryLangType).WithWhereClause("dut_name:chromeos", nil)
 				assert.NoErr(t, err)
 				assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
-					{Name: "dut_name", Value: &aip160.Value{Value: "chromeos"}},
+					{Name: "dut_name", Value: "chromeos"},
 				}))
 				assert.Loosely(t, q.whereClause, should.Equal("WHERE (? IN UNNEST(dut_name))"))
 			})
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/orderby_generator.go b/go/src/infra/fleetconsole/internal/database/queryutils/orderby_generator.go
index 1492b8d..b400bfb 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/orderby_generator.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/orderby_generator.go
@@ -81,7 +81,7 @@
 			fullPath := column.jsonFullPath(fields...)
 			params := make([]string, len(fullPath))
 			for i, field := range fullPath {
-				params[i] = q.Bind(column.Name, field)
+				params[i] = Bind(q, column.Name, field)
 			}
 
 			// In case of arrays in json it will order by based on the first element
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/query.go b/go/src/infra/fleetconsole/internal/database/queryutils/query.go
index f462c50..cab8f12 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/query.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/query.go
@@ -11,6 +11,14 @@
 	"strings"
 )
 
+// ParameterValueConstraint restricts the types accepted as query parameter
+// values, so unsupported types are rejected at compile time.
+// ~X matches any type whose underlying type is X; | separates the
+// alternative types.
+type ParameterValueConstraint interface {
+	~string | ~[]string
+}
+
 type QueryParameter struct {
 	Name  string
 	Value any
@@ -230,7 +238,7 @@
 
 	params := make([]string, len(values))
 	for i, value := range values {
-		params[i] = q.Bind(column, value)
+		params[i] = Bind(q, column, value)
 	}
 
 	if len(params) > 0 {
@@ -253,7 +261,9 @@
 // Bind binds a new query parameter with the given value, and returns
 // the name of the parameter.
 // The returned string is an injection-safe SQL expression.
-func (q *QueryBuilder) Bind(columnName string, value any) string {
+// The type constraint exists mainly so that unsupported value types
+// are rejected at compile time.
+func Bind[T ParameterValueConstraint](q *QueryBuilder, columnName string, value T) string {
 	var name string
 	if q.sqlLangType == BigQueryLangType {
 		name = "?"
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/query_test.go b/go/src/infra/fleetconsole/internal/database/queryutils/query_test.go
index a29168e..03303a1 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/query_test.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/query_test.go
@@ -7,7 +7,6 @@
 import (
 	"testing"
 
-	"go.chromium.org/luci/common/data/aip160"
 	"go.chromium.org/luci/common/testing/ftt"
 	"go.chromium.org/luci/common/testing/truth/assert"
 	"go.chromium.org/luci/common/testing/truth/should"
@@ -41,7 +40,7 @@
 			q, err := builder.Build([]string{"my-realm"})
 			assert.NoErr(t, err)
 			assert.Loosely(t, q.Parameters, should.Match([]any{
-				&aip160.Value{Value: "available"},
+				"available",
 				"my-realm",
 			}))
 			assert.Loosely(t, testutils.NormalizeSQL(q.Statement), should.Equal("SELECT id, dut_state, dut_name, labels, realm FROM \"Devices\" WHERE (dut_state = $1) AND (realm IN ($2) OR realm IS NULL) GROUP BY id ORDER BY id DESC LIMIT 10 OFFSET 20;"))
@@ -60,7 +59,7 @@
 			q, err := builder.Build([]string{"my-realm"})
 			assert.NoErr(t, err)
 			assert.Loosely(t, q.Parameters, should.Match([]any{
-				&aip160.Value{Value: "available"},
+				"available",
 				"my-realm",
 			}))
 			assert.Loosely(t, testutils.NormalizeSQL(q.Statement), should.Equal("SELECT id, dut_state, dut_name, labels, realm FROM `Devices` WHERE (dut_state = ?) AND (realm IN (?) OR realm IS NULL) GROUP BY id ORDER BY id DESC LIMIT 10 OFFSET 20;"))