diff --git a/amarillo/plugins/enhancer/services/gtfs_export.py b/amarillo/plugins/enhancer/services/gtfs_export.py index 9d9aea5..d389d2f 100644 --- a/amarillo/plugins/enhancer/services/gtfs_export.py +++ b/amarillo/plugins/enhancer/services/gtfs_export.py @@ -11,6 +11,7 @@ from amarillo.app.utils.utils import assert_folder_exists from amarillo.app.models.gtfs import GtfsTimeDelta, GtfsFeedInfo, GtfsAgency, GtfsRoute, GtfsStop, GtfsStopTime, GtfsTrip, GtfsCalendar, GtfsCalendarDate, GtfsShape from amarillo.app.services.stops import is_carpooling_stop from amarillo.app.services.gtfs_constants import * +from amarillo.app.utils.utils import geodesic_distance_in_m logger = logging.getLogger(__name__) @@ -19,7 +20,7 @@ class GtfsExport: stops_counter = 0 trips_counter = 0 - routes_counter = 0 + trip_counter = 0 stored_stops = {} @@ -74,19 +75,59 @@ class GtfsExport: for stop in stopSet["stops"].itertuples(): self._load_stored_stop(stop) cloned_trips = dict(ridestore.trips) + groups, cloned_trips = self.group_trips_into_routes(cloned_trips) + for group in groups: + self.convert_route(group) for url, trip in cloned_trips.items(): if self.bbox is None or trip.intersects(self.bbox): self._convert_trip(trip) + + def group_trips_into_routes(self, trips: dict): + ungrouped_trips = dict(trips) + route_groups = list() + current_route_id = 1 + for trip_id, trip in trips.items(): + if len(ungrouped_trips) == 0: break + + #find trips whose start and end stops are within a specified distance of the current trip + group = {key: value for key, value in ungrouped_trips.items() if self.trips_are_close(trip, value)} + + route_groups.append(group) + for key, grouped_trip in group.items(): + grouped_trip.route_id = str(current_route_id) + ungrouped_trips.pop(key) + + current_route_id += 1 + + return route_groups, trips + def trips_are_close(self, trip1, trip2): + trip1_start = trip1.path.coordinates[0] + trip1_end = trip1.path.coordinates[-1] + + trip2_start = trip2.path.coordinates[0] + trip2_end = trip2.path.coordinates[-1] + + return self.within_range(trip1_start, trip2_start) and self.within_range(trip1_end, trip2_end) + + def within_range(self, stop1, stop2): + MERGE_RANGE_M = 500 + return geodesic_distance_in_m(stop1, stop2) <= MERGE_RANGE_M + + def convert_route(self, route_group): + # TODO receives a group of trips and turns it into a route object, + # handling cases where there are multiple agencies' routes grouped together + pass + def _convert_trip(self, trip): - self.routes_counter += 1 + self.trip_counter += 1 self.routes.append(self._create_route(trip)) self.calendar.append(self._create_calendar(trip)) if not trip.runs_regularly: self.calendar_dates.append(self._create_calendar_date(trip)) - self.trips.append(self._create_trip(trip, self.routes_counter)) + self.trips.append(self._create_trip(trip, self.trip_counter)) self._append_stops_and_stop_times(trip) - self._append_shapes(trip, self.routes_counter) + self._append_shapes(trip, self.trip_counter) def _trip_headsign(self, destination): destination = destination.replace('(Deutschland)', '') @@ -133,7 +174,7 @@ class GtfsExport: return GtfsCalendarDate(trip.trip_id, self._convert_stop_date(trip.start), CALENDAR_DATES_EXCEPTION_TYPE_ADDED) def _create_trip(self, trip, shape_id): - return GtfsTrip(trip.trip_id, trip.trip_id, trip.trip_id, shape_id, trip.trip_headsign, NO_BIKES_ALLOWED) + return GtfsTrip(trip.route_id, trip.trip_id, trip.trip_id, shape_id, trip.trip_headsign, NO_BIKES_ALLOWED) def _convert_stop(self, stop): """