Adapt fleetconsole to AIP-160 parser breaking change
Change-Id: Ic53d3f8e89d18fd23a42591c7a91fb204d78ae3d
Reviewed-on: https://chromium-review.googlesource.com/c/infra/infra/+/7265145
Commit-Queue: Vaghinak Vardanyan <vaghinak@google.com>
Auto-Submit: Konrad Polarczyk <polarczyk@google.com>
Commit-Queue: Piotr Grabski-Gradziński <piotrgg@google.com>
Reviewed-by: Piotr Grabski-Gradziński <piotrgg@google.com>
Cr-Commit-Position: refs/heads/main@{#78099}
diff --git a/go/src/infra/fleetconsole/internal/consoleserver/list_resource_requests_test.go b/go/src/infra/fleetconsole/internal/consoleserver/list_resource_requests_test.go
index e31809f..29e2652 100644
--- a/go/src/infra/fleetconsole/internal/consoleserver/list_resource_requests_test.go
+++ b/go/src/infra/fleetconsole/internal/consoleserver/list_resource_requests_test.go
@@ -10,7 +10,6 @@
"cloud.google.com/go/bigquery"
- "go.chromium.org/luci/common/data/aip160"
"go.chromium.org/luci/common/testing/truth/assert"
"go.chromium.org/luci/common/testing/truth/should"
@@ -40,7 +39,7 @@
LIMIT 11
OFFSET 0;`),
))
- assert.Loosely(t, query.Parameters[0].Value.(*aip160.Value).Value, should.Equal("RR-001"))
- assert.Loosely(t, query.Parameters[1].Value.(*aip160.Value).Value, should.Equal("RR-002"))
+ assert.Loosely(t, query.Parameters[0].Value, should.Equal("RR-001"))
+ assert.Loosely(t, query.Parameters[1].Value, should.Equal("RR-002"))
})
}
diff --git a/go/src/infra/fleetconsole/internal/database/android_devicedb/devicecleanup.go b/go/src/infra/fleetconsole/internal/database/android_devicedb/devicecleanup.go
index f981d83..96f04a7 100644
--- a/go/src/infra/fleetconsole/internal/database/android_devicedb/devicecleanup.go
+++ b/go/src/infra/fleetconsole/internal/database/android_devicedb/devicecleanup.go
@@ -47,21 +47,21 @@
switch filterType := lf.GetFilterType().(type) {
case *configpb.LabFilter_ProcessAll:
if filterType.ProcessAll {
- conditions = append(conditions, fmt.Sprintf("%s = %s", LabName.Name, qb.Bind(LabName.Name, lf.GetName())))
+ conditions = append(conditions, fmt.Sprintf("%s = %s", LabName.Name, queryutils.Bind(qb, LabName.Name, lf.GetName())))
}
case *configpb.LabFilter_AllowedHostGroups:
groups := filterType.AllowedHostGroups.GetHostGroups()
if len(groups) > 0 {
// Devices in this lab, but NOT in allowed host groups, should be removed.
- params := utils.Map(groups, func(h string) string { return qb.Bind(HostGroup.Name, h) })
- conditions = append(conditions, fmt.Sprintf("(%s = %s AND %s && ARRAY[%s])", LabName.Name, qb.Bind(LabName.Name, lf.GetName()), HostGroup.Name, strings.Join(params, ", ")))
+ params := utils.Map(groups, func(h string) string { return queryutils.Bind(qb, HostGroup.Name, h) })
+ conditions = append(conditions, fmt.Sprintf("(%s = %s AND %s && ARRAY[%s])", LabName.Name, queryutils.Bind(qb, LabName.Name, lf.GetName()), HostGroup.Name, strings.Join(params, ", ")))
}
case *configpb.LabFilter_DeniedHostGroups:
groups := filterType.DeniedHostGroups.GetHostGroups()
if len(groups) > 0 {
// Devices in this lab AND in denied host groups, should be removed.
- params := utils.Map(groups, func(h string) string { return qb.Bind(HostGroup.Name, h) })
- conditions = append(conditions, fmt.Sprintf("(%s = %s AND NOT %s && ARRAY[%s])", LabName.Name, qb.Bind(LabName.Name, lf.GetName()), HostGroup.Name, strings.Join(params, ", ")))
+ params := utils.Map(groups, func(h string) string { return queryutils.Bind(qb, HostGroup.Name, h) })
+ conditions = append(conditions, fmt.Sprintf("(%s = %s AND NOT %s && ARRAY[%s])", LabName.Name, queryutils.Bind(qb, LabName.Name, lf.GetName()), HostGroup.Name, strings.Join(params, ", ")))
}
}
}
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator.go b/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator.go
index 04f9994..0a0ab71 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator.go
@@ -11,6 +11,8 @@
"go.chromium.org/luci/common/data/aip160"
"go.chromium.org/luci/common/errors"
+
+ "go.chromium.org/infra/fleetconsole/internal/utils"
)
// compositeArgInfo is the info used when the `arg` has composite expression.
@@ -136,8 +138,8 @@
return "", fmt.Errorf("missing field for json column %q", column.ExternalName)
}
- fullPath := column.jsonFullPath(fieldsToStrings(simple.Restriction.Comparable.Member.Fields)...)
- fieldToCheck := fmt.Sprintf("(%s #> %s)", column.Name, q.Bind(column.Name, fullPath))
+ fullPath := column.jsonFullPath(extractValuesFromFields(simple.Restriction.Comparable.Member.Fields)...)
+ fieldToCheck := fmt.Sprintf("(%s #> %s)", column.Name, Bind(q, column.Name, fullPath))
return fmt.Sprintf("(%s IS NULL OR %s IN ('[]'::jsonb, 'null'::jsonb))", fieldToCheck, fieldToCheck), nil
case ColumnTypeArray:
return fmt.Sprintf("(%s IS NULL OR cardinality(%s) = 0)", column.Name, column.Name), nil
@@ -188,7 +190,7 @@
return q.expressionQuery(restriction.Arg.Composite, &compositeArgInfo{
comparator: restriction.Comparator,
column: column,
- fields: fieldsToStrings(restriction.Comparable.Member.Fields),
+ fields: extractValuesFromFields(restriction.Comparable.Member.Fields),
}, comparableOverride)
}
@@ -213,7 +215,7 @@
// as labels have multiple values and we basically check the existence of the specified value in the array.
// Currently leaving it as `=` to comply with the frontend.
if restriction.Comparator == "=" {
- value, err := q.jsonArrayHasArgValue(restriction.Arg, column, fieldsToStrings(restriction.Comparable.Member.Fields))
+ value, err := q.jsonArrayHasArgValue(restriction.Arg, column, extractValuesFromFields(restriction.Comparable.Member.Fields))
if err != nil {
return "", errors.Fmt("argument for field %q: %w", column.ExternalName, err)
}
@@ -314,7 +316,7 @@
if restriction.Comparator == "" {
if len(restriction.Comparable.Member.Fields) > 0 {
value := restriction.Comparable.Member.Value.Value
- fields := strings.Join(fieldsToStrings(restriction.Comparable.Member.Fields), ".")
+ fields := strings.Join(extractValuesFromFields(restriction.Comparable.Member.Fields), ".")
return fmt.Errorf("fields are not allowed without an operator, try wrapping %s.%s in double quotes: \"%s.%s\"", value, fields, value, fields)
}
return fmt.Errorf("global properties are not supported without a comparator. Got: %q", restriction.Comparable.Member.Value.Value)
@@ -406,7 +408,7 @@
return "", errors.New("fields not implemented yet")
}
- return q.Bind(columnName, comparable.Member.Value), nil
+ return Bind(q, columnName, comparable.Member.Value.Value), nil
}
// Checks whether value exist in the array specified in the json path
@@ -432,17 +434,13 @@
fullPath := column.jsonFullPath(fields...)
params := make([]string, len(fullPath))
for i, field := range fullPath {
- params[i] = q.Bind("json_path_variable_"+strconv.Itoa(i)+":"+column.Name, field)
+ params[i] = Bind(q, "json_path_variable_"+strconv.Itoa(i)+":"+column.Name, field)
}
- value := q.Bind(column.Name, comparable.Member.Value)
+ value := Bind(q, column.Name, comparable.Member.Value.Value)
return fmt.Sprintf("%s ? %s", strings.Join(params, " -> "), value), nil
}
-func fieldsToStrings(fields []*aip160.Value) []string {
- strings := make([]string, len(fields))
- for i, field := range fields {
- strings[i] = field.Value
- }
- return strings
+func extractValuesFromFields(fields []*aip160.Value) []string {
+ return utils.Map(fields, func(x *aip160.Value) string { return x.Value })
}
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator_test.go b/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator_test.go
index 74069de..73e0f94 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator_test.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/filter_generator_test.go
@@ -7,7 +7,6 @@
import (
"testing"
- "go.chromium.org/luci/common/data/aip160"
"go.chromium.org/luci/common/testing/ftt"
"go.chromium.org/luci/common/testing/truth/assert"
"go.chromium.org/luci/common/testing/truth/should"
@@ -39,7 +38,7 @@
q, err := NewQueryBuilder(table).WithWhereClause("dut_state = available", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "dut_state", Value: &aip160.Value{Value: "available"}},
+ {Name: "dut_state", Value: "available"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (dut_state = $1)"))
})
@@ -47,7 +46,7 @@
q, err := NewQueryBuilder(table).WithWhereClause("dut_state != available", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "dut_state", Value: &aip160.Value{Value: "available"}},
+ {Name: "dut_state", Value: "available"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (dut_state <> $1)"))
})
@@ -55,7 +54,7 @@
q, err := NewQueryBuilder(table).WithWhereClause("test_date > 2020-01-01", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "test_date", Value: &aip160.Value{Value: "2020-01-01"}},
+ {Name: "test_date", Value: "2020-01-01"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (test_date > $1)"))
})
@@ -63,7 +62,7 @@
q, err := NewQueryBuilder(table).WithWhereClause("test_date < 2020-01-01", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "test_date", Value: &aip160.Value{Value: "2020-01-01"}},
+ {Name: "test_date", Value: "2020-01-01"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (test_date < $1)"))
})
@@ -71,7 +70,7 @@
q, err := NewQueryBuilder(table).WithWhereClause("test_date >= 2020-01-01", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "test_date", Value: &aip160.Value{Value: "2020-01-01"}},
+ {Name: "test_date", Value: "2020-01-01"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (test_date >= $1)"))
})
@@ -79,7 +78,7 @@
q, err := NewQueryBuilder(table).WithWhereClause("test_date <= 2020-01-01", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "test_date", Value: &aip160.Value{Value: "2020-01-01"}},
+ {Name: "test_date", Value: "2020-01-01"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (test_date <= $1)"))
})
@@ -87,7 +86,7 @@
q, err := NewQueryBuilder(table).WithWhereClause("host_group = \"acs:camerax\"", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "host_group", Value: &aip160.Value{Value: "acs:camerax", Quoted: true}},
+ {Name: "host_group", Value: "acs:camerax"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (host_group @> ARRAY[$1])"))
})
@@ -99,7 +98,7 @@
q, err := NewQueryBuilder(table).WithWhereClause("dut_state=(something)", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "dut_state", Value: &aip160.Value{Value: "something"}},
+ {Name: "dut_state", Value: "something"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (dut_state = $1)"))
})
@@ -124,7 +123,7 @@
{Name: "json_path_variable_2:labels", Value: string("values")},
{
Name: "labels",
- Value: &aip160.Value{Value: "OS_TYPE_CROS"},
+ Value: "OS_TYPE_CROS",
},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (labels -> $1 -> $2 -> $3 ? $4)"))
@@ -139,17 +138,17 @@
{Name: "json_path_variable_2:labels", Value: string("values")},
{
Name: "labels",
- Value: &aip160.Value{Value: "OS_TYPE_CROS"},
+ Value: "OS_TYPE_CROS",
},
{Name: "json_path_variable_0:labels", Value: "labels"},
{Name: "json_path_variable_1:labels", Value: "label-os_type"},
{Name: "json_path_variable_2:labels", Value: string("values")},
{
Name: "labels",
- Value: &aip160.Value{Value: "OS_TYPE_LABSTATION"},
+ Value: "OS_TYPE_LABSTATION",
},
- {Name: "dut_state", Value: &aip160.Value{Value: "DEVICE_STATE_LEASED"}},
- {Name: "dut_state", Value: &aip160.Value{Value: "DEVICE_STATE_AVAILABLE"}},
+ {Name: "dut_state", Value: "DEVICE_STATE_LEASED"},
+ {Name: "dut_state", Value: "DEVICE_STATE_AVAILABLE"},
// "labels", "label-os_type", "values", "OS_TYPE_CROS", "labels", "label-os_type", "values", "OS_TYPE_LABSTATION", "DEVICE_STATE_LEASED", "DEVICE_STATE_AVAILABLE",
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (((labels -> $1 -> $2 -> $3 ? $4) AND (labels -> $5 -> $6 -> $7 ? $8)) AND ((dut_state = $9) AND (dut_state = $10)))"))
@@ -159,9 +158,9 @@
q, err := NewQueryBuilder(table).WithWhereClause("dut_state = (a AND b OR c)", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "dut_state", Value: &aip160.Value{Value: "a"}},
- {Name: "dut_state", Value: &aip160.Value{Value: "b"}},
- {Name: "dut_state", Value: &aip160.Value{Value: "c"}},
+ {Name: "dut_state", Value: "a"},
+ {Name: "dut_state", Value: "b"},
+ {Name: "dut_state", Value: "c"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE ((dut_state = $1) AND ((dut_state = $2) OR (dut_state = $3)))"))
})
@@ -170,9 +169,9 @@
q, err := NewQueryBuilder(table).WithWhereClause("host_group = (a AND b OR c)", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "host_group", Value: &aip160.Value{Value: "a"}},
- {Name: "host_group", Value: &aip160.Value{Value: "b"}},
- {Name: "host_group", Value: &aip160.Value{Value: "c"}},
+ {Name: "host_group", Value: "a"},
+ {Name: "host_group", Value: "b"},
+ {Name: "host_group", Value: "c"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE ((host_group @> ARRAY[$1]) AND ((host_group @> ARRAY[$2]) OR (host_group @> ARRAY[$3])))"))
})
@@ -196,8 +195,8 @@
q, err := NewQueryBuilder(table).WithWhereClause("(NOT dut_state OR dut_state = (DEVICE_STATE_LEASED OR DEVICE_STATE_AVAILABLE))", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "dut_state", Value: &aip160.Value{Value: "DEVICE_STATE_LEASED"}},
- {Name: "dut_state", Value: &aip160.Value{Value: "DEVICE_STATE_AVAILABLE"}},
+ {Name: "dut_state", Value: "DEVICE_STATE_LEASED"},
+ {Name: "dut_state", Value: "DEVICE_STATE_AVAILABLE"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE ((dut_state = '' OR dut_state IS NULL) OR ((dut_state = $1) OR (dut_state = $2)))"))
})
@@ -207,7 +206,7 @@
q, err := NewQueryBuilder(table).SetSqlLangType(BigQueryLangType).WithWhereClause("dut_name:chromeos", nil)
assert.NoErr(t, err)
assert.Loosely(t, q.parameters.values, should.Match([]QueryParameter{
- {Name: "dut_name", Value: &aip160.Value{Value: "chromeos"}},
+ {Name: "dut_name", Value: "chromeos"},
}))
assert.Loosely(t, q.whereClause, should.Equal("WHERE (? IN UNNEST(dut_name))"))
})
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/orderby_generator.go b/go/src/infra/fleetconsole/internal/database/queryutils/orderby_generator.go
index 1492b8d..b400bfb 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/orderby_generator.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/orderby_generator.go
@@ -81,7 +81,7 @@
fullPath := column.jsonFullPath(fields...)
params := make([]string, len(fullPath))
for i, field := range fullPath {
- params[i] = q.Bind(column.Name, field)
+ params[i] = Bind(q, column.Name, field)
}
// In case of arrays in json it will order by based on the first element
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/query.go b/go/src/infra/fleetconsole/internal/database/queryutils/query.go
index f462c50..cab8f12 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/query.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/query.go
@@ -11,6 +11,14 @@
"strings"
)
+// ParameterValueConstraint restricts the value types that Bind accepts for a
+// query parameter, so that unsupported types are caught at compile time.
+// `~X` matches any type whose underlying type is X, and `|` unions the terms:
+// string, []string, and types defined on top of them are all accepted.
+type ParameterValueConstraint interface {
+ ~string | ~[]string
+}
+
type QueryParameter struct {
Name string
Value any
@@ -230,7 +238,7 @@
params := make([]string, len(values))
for i, value := range values {
- params[i] = q.Bind(column, value)
+ params[i] = Bind(q, column, value)
}
if len(params) > 0 {
@@ -253,7 +261,9 @@
// Bind binds a new query parameter with the given value, and returns
// the name of the parameter.
// The returned string is an injection-safe SQL expression.
-func (q *QueryBuilder) Bind(columnName string, value any) string {
+// The type constraint exists mainly to reject unsupported value types
+// at compile time.
+func Bind[T ParameterValueConstraint](q *QueryBuilder, columnName string, value T) string {
var name string
if q.sqlLangType == BigQueryLangType {
name = "?"
diff --git a/go/src/infra/fleetconsole/internal/database/queryutils/query_test.go b/go/src/infra/fleetconsole/internal/database/queryutils/query_test.go
index a29168e..03303a1 100644
--- a/go/src/infra/fleetconsole/internal/database/queryutils/query_test.go
+++ b/go/src/infra/fleetconsole/internal/database/queryutils/query_test.go
@@ -7,7 +7,6 @@
import (
"testing"
- "go.chromium.org/luci/common/data/aip160"
"go.chromium.org/luci/common/testing/ftt"
"go.chromium.org/luci/common/testing/truth/assert"
"go.chromium.org/luci/common/testing/truth/should"
@@ -41,7 +40,7 @@
q, err := builder.Build([]string{"my-realm"})
assert.NoErr(t, err)
assert.Loosely(t, q.Parameters, should.Match([]any{
- &aip160.Value{Value: "available"},
+ "available",
"my-realm",
}))
assert.Loosely(t, testutils.NormalizeSQL(q.Statement), should.Equal("SELECT id, dut_state, dut_name, labels, realm FROM \"Devices\" WHERE (dut_state = $1) AND (realm IN ($2) OR realm IS NULL) GROUP BY id ORDER BY id DESC LIMIT 10 OFFSET 20;"))
@@ -60,7 +59,7 @@
q, err := builder.Build([]string{"my-realm"})
assert.NoErr(t, err)
assert.Loosely(t, q.Parameters, should.Match([]any{
- &aip160.Value{Value: "available"},
+ "available",
"my-realm",
}))
assert.Loosely(t, testutils.NormalizeSQL(q.Statement), should.Equal("SELECT id, dut_state, dut_name, labels, realm FROM `Devices` WHERE (dut_state = ?) AND (realm IN (?) OR realm IS NULL) GROUP BY id ORDER BY id DESC LIMIT 10 OFFSET 20;"))