diff --git a/google/cloud/ndb/_gql.py b/google/cloud/ndb/_gql.py index 146c7b1c..dfc1c6d6 100644 --- a/google/cloud/ndb/_gql.py +++ b/google/cloud/ndb/_gql.py @@ -667,7 +667,9 @@ def _args_to_val(self, func, args): if func == "nop": return vals[0] # May be a Parameter pfunc = query_module.ParameterizedFunction(func, vals) - return pfunc + if pfunc.is_parameterized(): + return pfunc + return pfunc.resolve({}, {}) def query_filters(self, model_class, filters): """Get the filters in a format compatible with the Query constructor""" @@ -681,6 +683,8 @@ def query_filters(self, model_class, filters): val = self._args_to_val(func, args) if isinstance(val, query_module.ParameterizedThing): node = query_module.ParameterNode(prop, op, val) + elif op == "in": + node = prop._IN(val) else: node = prop._comparison(op, val) filters.append(node) @@ -762,3 +766,19 @@ def __eq__(self, other): def __repr__(self): return "Literal(%s)" % repr(self._value) + + +def _raise_not_implemented(func): + def raise_inner(value): + raise NotImplementedError( + "GQL function {} is not implemented".format(func) + ) + + return raise_inner + + +FUNCTIONS = { + "list": list, + "user": _raise_not_implemented("user"), + "key": _raise_not_implemented("key"), +} diff --git a/google/cloud/ndb/query.py b/google/cloud/ndb/query.py index 6716ae5e..17fc81bf 100644 --- a/google/cloud/ndb/query.py +++ b/google/cloud/ndb/query.py @@ -362,24 +362,38 @@ class ParameterizedFunction(ParameterizedThing): """ def __init__(self, func, values): - self.__func = func - self.__values = values + self.func = func + self.values = values + + from google.cloud.ndb import _gql # avoid circular import + + _func = _gql.FUNCTIONS.get(func) + if _func is None: + raise ValueError("Unknown GQL function: {}".format(func)) + self._func = _func def __repr__(self): - return "ParameterizedFunction(%r, %r)" % (self.__func, self.__values) + return "ParameterizedFunction(%r, %r)" % (self.func, self.values) def __eq__(self, other): if not isinstance(other, ParameterizedFunction): return NotImplemented - return self.__func == other.__func and self.__values == other.__values + return self.func == other.func and self.values == other.values - @property - def func(self): - return self.__func + def is_parameterized(self): + for value in self.values: + if isinstance(value, Parameter): + return True + return False - @property - def values(self): - return self.__values + def resolve(self, bindings, used): + values = [] + for value in self.values: + if isinstance(value, Parameter): + value = value.resolve(bindings, used) + values.append(value) + + return self._func(values) class Node(object): diff --git a/tests/system/test_query.py b/tests/system/test_query.py index 541bd0b7..f97dd327 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -1421,3 +1421,27 @@ class SomeKind(ndb.Model): query = SomeKind.gql("WHERE foo = :1", 2) results = query.fetch() assert results[0].foo == 2 + + +@pytest.mark.usefixtures("client_context") +def test_IN(ds_entity): + for i in range(5): + entity_id = test_utils.system.unique_resource_id() + ds_entity(KIND, entity_id, foo=i) + + class SomeKind(ndb.Model): + foo = ndb.IntegerProperty() + + eventually(SomeKind.query().fetch, _length_equals(5)) + + query = SomeKind.gql("where foo in (2, 3)").order(SomeKind.foo) + results = query.fetch() + assert len(results) == 2 + assert results[0].foo == 2 + assert results[1].foo == 3 + + query = SomeKind.gql("where foo in :1", [2, 3]).order(SomeKind.foo) + results = query.fetch() + assert len(results) == 2 + assert results[0].foo == 2 + assert results[1].foo == 3 diff --git a/tests/unit/test__gql.py b/tests/unit/test__gql.py index d0045e3f..6620b1c9 100644 --- a/tests/unit/test__gql.py +++ b/tests/unit/test__gql.py @@ -18,6 +18,7 @@ from google.cloud.ndb import exceptions from google.cloud.ndb import model from google.cloud.ndb import _gql as gql_module +from google.cloud.ndb import query as query_module GQL_QUERY = """ @@ -329,12 +330,28 @@ class SomeKind(model.Model): @pytest.mark.usefixtures("in_context") def test_get_query_in(): class SomeKind(model.Model): - prop1 = model.StringProperty() + prop1 = model.IntegerProperty() gql = gql_module.GQL( "SELECT prop1 FROM SomeKind WHERE prop1 IN (1, 2, 3)" ) query = gql.get_query() + assert query.filters == query_module.OR( + query_module.FilterNode("prop1", "=", 1), + query_module.FilterNode("prop1", "=", 2), + query_module.FilterNode("prop1", "=", 3), + ) + + @staticmethod + @pytest.mark.usefixtures("in_context") + def test_get_query_in_parameterized(): + class SomeKind(model.Model): + prop1 = model.StringProperty() + + gql = gql_module.GQL( + "SELECT prop1 FROM SomeKind WHERE prop1 IN (:1, :2, :3)" + ) + query = gql.get_query() assert "'in'," in str(query.filters) @staticmethod @@ -346,3 +363,19 @@ class SomeKind(model.Model): gql = gql_module.GQL("SELECT __key__ FROM SomeKind WHERE prop1='a'") query = gql.get_query() assert query.default_options.keys_only is True + + +class TestFUNCTIONS: + @staticmethod + def test_list(): + assert gql_module.FUNCTIONS["list"]((1, 2)) == [1, 2] + + @staticmethod + def test_user(): + with pytest.raises(NotImplementedError): + gql_module.FUNCTIONS["user"]("any arg") + + @staticmethod + def test_key(): + with pytest.raises(NotImplementedError): + gql_module.FUNCTIONS["key"]("any arg") diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index d798e228..d358cc82 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -299,29 +299,34 @@ class TestParameterizedFunction: @staticmethod def test_constructor(): query = query_module.ParameterizedFunction( - "user", query_module.Parameter(1) + "user", [query_module.Parameter(1)] ) assert query.func == "user" - assert query.values == query_module.Parameter(1) + assert query.values == [query_module.Parameter(1)] + + @staticmethod + def test_constructor_bad_function(): + with pytest.raises(ValueError): + query_module.ParameterizedFunction("notafunc", ()) @staticmethod def test___repr__(): query = query_module.ParameterizedFunction( - "user", query_module.Parameter(1) + "user", [query_module.Parameter(1)] ) assert ( - query.__repr__() == "ParameterizedFunction('user', Parameter(1))" + query.__repr__() == "ParameterizedFunction('user', [Parameter(1)])" ) @staticmethod def test___eq__parameter(): query = query_module.ParameterizedFunction( - "user", query_module.Parameter(1) + "user", [query_module.Parameter(1)] ) assert ( query.__eq__( query_module.ParameterizedFunction( - "user", query_module.Parameter(1) + "user", [query_module.Parameter(1)] ) ) is True @@ -330,10 +335,37 @@ def test___eq__parameter(): @staticmethod def test___eq__no_parameter(): query = query_module.ParameterizedFunction( - "user", query_module.Parameter(1) + "user", [query_module.Parameter(1)] ) assert query.__eq__(42) is NotImplemented + @staticmethod + def test_is_parameterized_True(): + query = query_module.ParameterizedFunction( + "user", [query_module.Parameter(1)] + ) + assert query.is_parameterized() + + @staticmethod + def test_is_parameterized_False(): + query = query_module.ParameterizedFunction("user", [1]) + assert not query.is_parameterized() + + @staticmethod + def test_is_parameterized_no_arguments(): + query = query_module.ParameterizedFunction("user", ()) + assert not query.is_parameterized() + + @staticmethod + def test_resolve(): + query = query_module.ParameterizedFunction( + "list", [1, query_module.Parameter(2), query_module.Parameter(3)] + ) + used = {} + resolved = query.resolve({2: 4, 3: 6}, used) + assert resolved == [1, 4, 6] + assert used == {2: True, 3: True} + class TestNode: @staticmethod