diff --git a/amarillo/plugins/enhancer/services/gtfs_export.py b/amarillo/plugins/enhancer/services/gtfs_export.py index 1abd5c0..d9538da 100644 --- a/amarillo/plugins/enhancer/services/gtfs_export.py +++ b/amarillo/plugins/enhancer/services/gtfs_export.py @@ -8,11 +8,14 @@ import logging import re from amarillo.app.utils.utils import assert_folder_exists -from amarillo.app.models.gtfs import GtfsTimeDelta, GtfsFeedInfo, GtfsAgency, GtfsRoute, GtfsStop, GtfsStopTime, GtfsTrip, GtfsCalendar, GtfsCalendarDate, GtfsShape +from amarillo.app.models.gtfs import GtfsTimeDelta, GtfsFeedInfo, GtfsAgency, GtfsRoute, GtfsStop, GtfsTrip, GtfsCalendar, GtfsCalendarDate, GtfsShape, GtfsDriver, GtfsAdditionalRidesharingInfo +from amarillo.app.models.Carpool import Driver, RidesharingInfo from amarillo.app.services.stops import is_carpooling_stop from amarillo.app.services.gtfs_constants import * from amarillo.app.utils.utils import geodesic_distance_in_m +from amarillo.plugins.enhancer.services.trips import Trip + logger = logging.getLogger(__name__) @@ -24,7 +27,6 @@ class GtfsExport: stored_stops = {} - # TODO: add lists self.drivers and self.additional_ridesharing_infos def __init__(self, agencies, feed_info, ridestore, stopstore, bbox = None): self.stops = {} self.routes = [] @@ -34,6 +36,8 @@ class GtfsExport: self.stop_times = [] self.calendar = [] self.shapes = [] + self.drivers = {} #use a dictionary to avoid duplicate ids + self.additional_ridesharing_infos = [] self.agencies = agencies self.feed_info = feed_info self.localized_to = " nach " @@ -54,12 +58,14 @@ class GtfsExport: self._write_csvfile(gtfsfolder, 'stops.txt', self.stops.values()) self._write_csvfile(gtfsfolder, 'stop_times.txt', self.stop_times) self._write_csvfile(gtfsfolder, 'shapes.txt', self.shapes) - # TODO: write driver.txt and additional_ridesharing_info.txt + self._write_csvfile(gtfsfolder, 'driver.txt', self.drivers.values()) + self._write_csvfile(gtfsfolder, 'additional_ridesharing_info.txt', self.additional_ridesharing_infos) self._zip_files(gtfszip_filename, gtfsfolder) def _zip_files(self, gtfszip_filename, gtfsfolder): gtfsfiles = ['agency.txt', 'feed_info.txt', 'routes.txt', 'trips.txt', - 'calendar.txt', 'calendar_dates.txt', 'stops.txt', 'stop_times.txt', 'shapes.txt'] + 'calendar.txt', 'calendar_dates.txt', 'stops.txt', 'stop_times.txt', + 'shapes.txt', 'driver.txt', 'additional_ridesharing_info.txt'] with ZipFile(gtfszip_filename, 'w') as gtfszip: for gtfsfile in gtfsfiles: gtfszip.write(gtfsfolder+'/'+gtfsfile, gtfsfile) @@ -132,7 +138,7 @@ class GtfsExport: self.routes.append(self._create_route(agency, trip.route_id, trip.route_name)) - def _convert_trip(self, trip): + def _convert_trip(self, trip: Trip): self.trip_counter += 1 self.calendar.append(self._create_calendar(trip)) if not trip.runs_regularly: @@ -140,7 +146,26 @@ class GtfsExport: self.trips.append(self._create_trip(trip, self.trip_counter)) self._append_stops_and_stop_times(trip) self._append_shapes(trip, self.trip_counter) + + if(trip.driver is not None): + self.drivers[trip.driver.driver_id] = self._convert_driver(trip.driver) + if(trip.additional_ridesharing_info is not None): + self.additional_ridesharing_infos.append( + self._convert_additional_ridesharing_info(trip.trip_id, trip.additional_ridesharing_info)) + def _convert_driver(self, driver: Driver): + return GtfsDriver(driver.driver_id, driver.profile_picture, driver.rating) + + def _convert_additional_ridesharing_info(self, trip_id, info: RidesharingInfo): + # if we don't specify .value, the enum will appear in the export as e.g. LuggageSize.large + # and missing optional values get None + def get_enum_value(enum): + return enum.value if enum is not None else None + + return GtfsAdditionalRidesharingInfo( + trip_id, info.number_free_seats, get_enum_value(info.same_gender), get_enum_value(info.luggage_size), get_enum_value(info.animal_car), + info.car_model, info.car_brand, info.creation_date, get_enum_value(info.smoking), info.payment_method) + def _trip_headsign(self, destination): destination = destination.replace('(Deutschland)', '') destination = destination.replace(', Deutschland', '') @@ -185,8 +210,9 @@ class GtfsExport: def _create_calendar_date(self, trip): return GtfsCalendarDate(trip.trip_id, self._convert_stop_date(trip.start), CALENDAR_DATES_EXCEPTION_TYPE_ADDED) - def _create_trip(self, trip, shape_id): - return GtfsTrip(trip.route_id, trip.trip_id, trip.trip_id, shape_id, trip.trip_headsign, NO_BIKES_ALLOWED, trip.url) + def _create_trip(self, trip : Trip, shape_id): + driver_id = None if trip.driver is None else trip.driver.driver_id + return GtfsTrip(trip.route_id, trip.trip_id, driver_id, trip.trip_id, shape_id, trip.trip_headsign, NO_BIKES_ALLOWED, trip.url) def _convert_stop(self, stop): """ diff --git a/amarillo/plugins/enhancer/services/trips.py b/amarillo/plugins/enhancer/services/trips.py index 6fb58ff..4e50235 100644 --- a/amarillo/plugins/enhancer/services/trips.py +++ b/amarillo/plugins/enhancer/services/trips.py @@ -4,6 +4,7 @@ from amarillo.app.services.gtfs_constants import * from amarillo.app.services.routing import RoutingService, RoutingException from amarillo.app.services.stops import is_carpooling_stop from amarillo.app.utils.utils import assert_folder_exists, is_older_than_days, yesterday, geodesic_distance_in_m +from amarillo.app.models.Carpool import Driver, RidesharingInfo from shapely.geometry import Point, LineString, box from geojson_pydantic.geometries import LineString as GeoJSONLineString from datetime import datetime, timedelta @@ -17,7 +18,7 @@ logger = logging.getLogger(__name__) class Trip: # TODO: add driver attributes, additional ridesharing info - def __init__(self, trip_id, route_name, headsign, url, calendar, departureTime, path, agency, lastUpdated, stop_times, bbox): + def __init__(self, trip_id, route_name, headsign, url, calendar, departureTime, path, agency, lastUpdated, stop_times, driver: Driver, additional_ridesharing_info: RidesharingInfo, bbox): if isinstance(calendar, set): self.runs_regularly = True self.weekdays = [ @@ -43,6 +44,8 @@ class Trip: self.stops = [] self.lastUpdated = lastUpdated self.stop_times = stop_times + self.driver = driver + self.additional_ridesharing_info = additional_ridesharing_info self.bbox = bbox self.route_name = route_name self.trip_headsign = headsign @@ -203,7 +206,7 @@ class TripTransformer: def __init__(self, stops_store): self.stops_store = stops_store - def transform_to_trip(self, carpool): + def transform_to_trip(self, carpool : Carpool): stop_times = self._convert_stop_times(carpool) route_name = carpool.stops[0].name + " nach " + carpool.stops[-1].name headsign= carpool.stops[-1].name @@ -215,8 +218,7 @@ class TripTransformer: max([pt[0] for pt in path.coordinates]), max([pt[1] for pt in path.coordinates])) - # TODO: pass driver and ridesharing info object to the Trip constructor - trip = Trip(trip_id, route_name, headsign, str(carpool.deeplink), carpool.departureDate, carpool.departureTime, carpool.path, carpool.agency, carpool.lastUpdated, stop_times, bbox) + trip = Trip(trip_id, route_name, headsign, str(carpool.deeplink), carpool.departureDate, carpool.departureTime, carpool.path, carpool.agency, carpool.lastUpdated, stop_times, carpool.driver, carpool.additional_ridesharing_info, bbox) return trip