git.openstreetmap.org Git - nominatim.git/commitdiff
clean up ST_DWithin and intersects() functions
author: Sarah Hoffmann <lonvia@denofr.de>
Tue, 5 Dec 2023 17:02:40 +0000 (18:02 +0100)
committer: Sarah Hoffmann <lonvia@denofr.de>
Thu, 7 Dec 2023 08:31:00 +0000 (09:31 +0100)
A non-index version of ST_DWithin is not necessary. ST_Distance
can be used for that purpose. Index use for intersects can be
covered with a simple parameter.

nominatim/api/reverse.py
nominatim/api/search/db_searches.py
nominatim/db/sqlalchemy_types/geometry.py

index fb4c0b23d0f2fd4790d942b31508126f39a2d379..df5c10f2669951d0f0bbcf778815724c23a719ba 100644 (file)
@@ -180,7 +180,7 @@ class ReverseGeocoder:
         diststr = sa.text(f"{distance}")
 
         sql: SaLambdaSelect = sa.lambda_stmt(lambda: _select_from_placex(t)
-                .where(t.c.geometry.ST_DWithin(WKT_PARAM, diststr))
+                .where(t.c.geometry.within_distance(WKT_PARAM, diststr))
                 .where(t.c.indexed_status == 0)
                 .where(t.c.linked_place_id == None)
                 .where(sa.or_(sa.not_(t.c.geometry.is_area()),
@@ -219,7 +219,7 @@ class ReverseGeocoder:
         t = self.conn.t.placex
 
         sql: SaLambdaSelect = sa.lambda_stmt(lambda: _select_from_placex(t)
-                .where(t.c.geometry.ST_DWithin(WKT_PARAM, 0.001))
+                .where(t.c.geometry.within_distance(WKT_PARAM, 0.001))
                 .where(t.c.parent_place_id == parent_place_id)
                 .where(sa.func.IsAddressPoint(t))
                 .where(t.c.indexed_status == 0)
@@ -241,7 +241,7 @@ class ReverseGeocoder:
                    sa.select(t,
                              t.c.linegeo.ST_Distance(WKT_PARAM).label('distance'),
                              _locate_interpolation(t))
-                     .where(t.c.linegeo.ST_DWithin(WKT_PARAM, distance))
+                     .where(t.c.linegeo.within_distance(WKT_PARAM, distance))
                      .where(t.c.startnumber != None)
                      .order_by('distance')
                      .limit(1))
@@ -275,7 +275,7 @@ class ReverseGeocoder:
             inner = sa.select(t,
                               t.c.linegeo.ST_Distance(WKT_PARAM).label('distance'),
                               _locate_interpolation(t))\
-                      .where(t.c.linegeo.ST_DWithin(WKT_PARAM, 0.001))\
+                      .where(t.c.linegeo.within_distance(WKT_PARAM, 0.001))\
                       .where(t.c.parent_place_id == parent_place_id)\
                       .order_by('distance')\
                       .limit(1)\
index 2b4dfd3c9bdc5c78c7ccaa07bb1a18a4bef97d16..48bd6272c807de60edd454202fbab1d1540f9c09 100644 (file)
@@ -56,7 +56,7 @@ NEAR_RADIUS_PARAM: SaBind = sa.bindparam('near_radius')
 COUNTRIES_PARAM: SaBind = sa.bindparam('countries')
 
 def _within_near(t: SaFromClause) -> Callable[[], SaExpression]:
-    return lambda: t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)
+    return lambda: t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)
 
 def _exclude_places(t: SaFromClause) -> Callable[[], SaExpression]:
     return lambda: t.c.place_id.not_in(sa.bindparam('excluded'))
@@ -366,7 +366,7 @@ class PoiSearch(AbstractSearch):
                            .add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
                                          .label('importance'))\
                            .where(t.c.linked_place_id == None) \
-                           .where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
+                           .where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
                            .order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
                            .limit(LIMIT_PARAM)
 
@@ -403,8 +403,8 @@ class PoiSearch(AbstractSearch):
 
                     if details.near and details.near_radius is not None:
                         sql = sql.order_by(table.c.centroid.ST_Distance(NEAR_PARAM))\
-                                 .where(table.c.centroid.ST_DWithin(NEAR_PARAM,
-                                                                    NEAR_RADIUS_PARAM))
+                                 .where(table.c.centroid.within_distance(NEAR_PARAM,
+                                                                         NEAR_RADIUS_PARAM))
 
                     if self.countries:
                         sql = sql.where(t.c.country_code.in_(self.countries.values))
@@ -632,11 +632,11 @@ class PlaceSearch(AbstractSearch):
             sql = sql.where(tsearch.c.address_rank > 9)
             tpc = conn.t.postcode
             pcs = self.postcodes.values
-            if self.expected_count > 1000:
+            if self.expected_count > 5000:
                 # Many results expected. Restrict by postcode.
                 sql = sql.where(sa.select(tpc.c.postcode)
                                   .where(tpc.c.postcode.in_(pcs))
-                                  .where(tsearch.c.centroid.ST_DWithin(tpc.c.geometry, 0.12))
+                                  .where(tsearch.c.centroid.within_distance(tpc.c.geometry, 0.12))
                                   .exists())
 
             # Less results, only have a preference for close postcodes
@@ -648,27 +648,26 @@ class PlaceSearch(AbstractSearch):
 
         if details.viewbox is not None:
             if details.bounded_viewbox:
-                if details.viewbox.area < 0.2:
-                    sql = sql.where(tsearch.c.centroid.intersects(VIEWBOX_PARAM))
-                else:
-                    sql = sql.where(tsearch.c.centroid.ST_Intersects_no_index(VIEWBOX_PARAM))
+                sql = sql.where(tsearch.c.centroid
+                                         .intersects(VIEWBOX_PARAM,
+                                                     use_index=details.viewbox.area < 0.2))
             elif self.expected_count >= 10000:
-                if details.viewbox.area < 0.5:
-                    sql = sql.where(tsearch.c.centroid.intersects(VIEWBOX2_PARAM))
-                else:
-                    sql = sql.where(tsearch.c.centroid.ST_Intersects_no_index(VIEWBOX2_PARAM))
+                sql = sql.where(tsearch.c.centroid
+                                         .intersects(VIEWBOX2_PARAM,
+                                                     use_index=details.viewbox.area < 0.5))
             else:
-                penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0),
-                                   (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5),
+                penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM, use_index=False), 0.0),
+                                   (t.c.geometry.intersects(VIEWBOX2_PARAM, use_index=False), 0.5),
                                    else_=1.0)
 
         if details.near is not None:
             if details.near_radius is not None:
                 if details.near_radius < 0.1:
-                    sql = sql.where(tsearch.c.centroid.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM))
+                    sql = sql.where(tsearch.c.centroid.within_distance(NEAR_PARAM,
+                                                                       NEAR_RADIUS_PARAM))
                 else:
-                    sql = sql.where(tsearch.c.centroid.ST_DWithin_no_index(NEAR_PARAM,
-                                                                           NEAR_RADIUS_PARAM))
+                    sql = sql.where(tsearch.c.centroid
+                                             .ST_Distance(NEAR_PARAM) <  NEAR_RADIUS_PARAM)
             sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM))
                                       .label('importance'))
             sql = sql.order_by(sa.desc(sa.text('importance')))
index a36e8c462acfce3b4cc5e730b2eb5c008f1dfa14..4520fc8e53c84dc277f0ce143550361c0801288b 100644 (file)
@@ -165,7 +165,6 @@ def spatialite_dwithin_column(element: SaColumn,
               compiler.process(dist, **kw))
 
 
-
 class Geometry(types.UserDefinedType): # type: ignore[type-arg]
     """ Simplified type decorator for PostGIS geometry. This type
         only supports geometries in 4326 projection.
@@ -206,7 +205,10 @@ class Geometry(types.UserDefinedType): # type: ignore[type-arg]
 
     class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
 
-        def intersects(self, other: SaColumn) -> 'sa.Operators':
+        def intersects(self, other: SaColumn, use_index: bool = True) -> 'sa.Operators':
+            if not use_index:
+                return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self.expr), other)
+
             if isinstance(self.expr, sa.Column):
                 return Geometry_ColumnIntersectsBbox(self.expr, other)
 
@@ -221,20 +223,11 @@ class Geometry(types.UserDefinedType): # type: ignore[type-arg]
             return Geometry_IsAreaLike(self)
 
 
-        def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
+        def within_distance(self, other: SaColumn, distance: SaColumn) -> SaColumn:
             if isinstance(self.expr, sa.Column):
                 return Geometry_ColumnDWithin(self.expr, other, distance)
 
-            return sa.func.ST_DWithin(self.expr, other, distance)
-
-
-        def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
-            return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
-                                      other, distance)
-
-
-        def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
-            return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self), other)
+            return self.ST_Distance(other) < distance
 
 
         def ST_Distance(self, other: SaColumn) -> SaColumn:
@@ -313,18 +306,3 @@ def _add_function_alias(func: str, ftype: type, alias: str) -> None:
 
 for alias in SQLITE_FUNCTION_ALIAS:
     _add_function_alias(*alias)
-
-
-class ST_DWithin(sa.sql.functions.GenericFunction[Any]):
-    name = 'ST_DWithin'
-    inherit_cache = True
-
-
-@compiles(ST_DWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
-def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
-    geom1, geom2, dist = list(element.clauses)
-    return "(MbrIntersects(%s, ST_Expand(%s, %s)) = 1 AND ST_Distance(%s, %s) <= %s)" % (
-        compiler.process(geom1, **kw), compiler.process(geom2, **kw),
-        compiler.process(dist, **kw),
-        compiler.process(geom1, **kw), compiler.process(geom2, **kw),
-        compiler.process(dist, **kw))