Repo created

Fr4nz D13trich 2025-11-22 13:58:55 +01:00
parent 4af19165ec
commit 68073add76
12458 changed files with 12350765 additions and 2 deletions


@@ -0,0 +1,89 @@
"""Copyright (c) 2015, Emilia Petrisor
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."""
import numpy as np
def knot_interval(i_pts, alpha=0.5, closed=False):
    # i_pts is the list of interpolatory points P[0], P[1], ..., P[n]
    if len(i_pts) < 4:
        raise ValueError('CR-curves need at least 4 interpolatory points')
    if closed:
        # Extends i_pts in place; Catmull_Rom below relies on this side effect
        # to iterate over the wrapped-around segments of a closed curve.
        i_pts += [i_pts[0], i_pts[1], i_pts[2]]
    i_pts = np.array(i_pts)
    dist = np.linalg.norm(i_pts[1:, :] - i_pts[:-1, :], axis=1)
    return dist ** alpha
def ctrl_bezier(P, d):
    # Associate to 4 consecutive interpolatory points and the corresponding
    # three d-values the Bezier control points.
    if len(P) != 4 or len(d) != 3:
        raise ValueError('The list of points and knot intervals have inappropriate len')
P = np.array(P)
bz = [0] * 4
bz[0] = P[1]
bz[1] = (d[0] ** 2 * P[2] - d[1] ** 2 * P[0] + (2 * d[0] ** 2 + 3 * d[0] * d[1] + d[1] ** 2) * P[1]) / (
3 * d[0] * (d[0] + d[1]))
bz[2] = (d[2] ** 2 * P[1] - d[1] ** 2 * P[3] + (2 * d[2] ** 2 + 3 * d[2] * d[1] + d[1] ** 2) * P[2]) / (
3 * d[2] * (d[1] + d[2]))
bz[3] = P[2]
return bz
def Bezier_curve(bz, nr=100):
# implements the de Casteljau algorithm to compute nr points on a Bezier curve
t = np.linspace(0, 1, nr)
N = len(bz)
points = [] # the list of points to be computed on the Bezier curve
for i in range(nr): # for each parameter t[i] evaluate a point on the Bezier curve
# via De Casteljau algorithm
aa = np.copy(bz)
for r in range(1, N):
aa[:N - r, :] = (1 - t[i]) * aa[:N - r, :] + t[i] * aa[1:N - r + 1, :] # convex combination
points.append(aa[0, :])
return points
def Catmull_Rom(i_pts, alpha=0.5, closed=False):
# returns the list of points computed on the interpolating CR curve
# i_pts the list of interpolatory points P[0], P[1], ...P[n]
curve_pts = [] # the list of all points to be computed on the CR curve
d = knot_interval(i_pts, alpha=alpha, closed=closed)
for k in range(len(i_pts) - 3):
cb = ctrl_bezier(i_pts[k:k + 4], d[k:k + 3])
curve_pts.extend(Bezier_curve(cb, nr=100))
return np.array(curve_pts)
def segment_to_Catmull_Rom_curve(p1, s1, s2, p2, nr=100, alpha=0.5):
    # Returns nr points of the single CR segment that interpolates s1 and s2;
    # p1 and p2 serve as guide points steering the tangents at the endpoints.
    i_pts = [p1, s1, s2, p2]
    d = knot_interval(i_pts, alpha=alpha, closed=False)
    cb = ctrl_bezier(i_pts, d)
    return Bezier_curve(cb, nr=nr)
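A minimal usage sketch (the sample points below are arbitrary; the module name bezier_curves matches the import in the transit graph generator further down):

    from bezier_curves import Catmull_Rom

    pts = [[0.0, 0.0], [1.0, 2.0], [3.0, 2.5], [4.0, 0.5]]
    curve = Catmull_Rom(pts, alpha=0.5)  # centripetal CR segment from pts[1] to pts[2]
    print(curve.shape)                   # (100, 2): one Bezier segment, 100 sampled points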


@@ -0,0 +1,288 @@
#!/usr/bin/env python3
"""Parses GTFS feeds urls:
https://transit.land/ - Transitland
https://storage.googleapis.com/storage/v1/b/mdb-csv/o/sources.csv?alt=media
- Mobility Database (https://mobilitydata.org/)
Crawls all the urls, loads feed zips and extracts to the specified directory."""
import argparse
import concurrent.futures
import io
import json
import logging
import os
import csv
import time
import zipfile
import requests
MAX_RETRIES = 2
MAX_SLEEP_TIMEOUT_S = 30
RAW_FILE_MOBILITYDB = "raw_mobilitydb.csv"
URLS_FILE_TRANSITLAND = "feed_urls_transitland.txt"
URLS_FILE_MOBILITYDB = "feed_urls_mobilitydb.txt"
URL_MOBILITYDB_GTFS_SOURCE = "https://storage.googleapis.com/storage/v1/b/mdb-csv/o/sources.csv?alt=media"
THREADS_COUNT = 2
MAX_INDEX_LEN = 4
logger = logging.getLogger(__name__)
def download_gtfs_sources_mobilitydb(path):
    """Downloads the csv catalogue from the Mobility Database."""
    try:
        req = requests.get(URL_MOBILITYDB_GTFS_SOURCE)
        # raise_for_status() is needed for the HTTPError handler below to fire.
        req.raise_for_status()
        with open(os.path.join(path, RAW_FILE_MOBILITYDB), 'wb') as csv_file:
            csv_file.write(req.content)
    except requests.exceptions.HTTPError as http_err:
        logger.error(
            f"HTTP error {http_err} downloading csv from {URL_MOBILITYDB_GTFS_SOURCE}")
def get_gtfs_urls_mobilitydb(path, countries_list):
    """Extracts the feed urls from the downloaded csv file."""
    download_from_all_countries = not countries_list
    download_gtfs_sources_mobilitydb(path)
    with open(os.path.join(path, RAW_FILE_MOBILITYDB), encoding='UTF-8') as csv_file:
        # DictReader already consumes the header row; an extra next() call
        # here would silently drop the first source.
        raw_sources = csv.DictReader(csv_file)
        urls = [row["urls.direct_download"] for row in raw_sources
                if download_from_all_countries or row["location.country_code"] in countries_list]
    write_list_to_file(os.path.join(path, URLS_FILE_MOBILITYDB), urls)
def get_feeds_links(data):
    """Extracts feed urls from the GTFS json description."""
    gtfs_feeds_urls = []
    for feed in data:
        # Possible values: MDS, GBFS, GTFS_RT, GTFS
        if feed["spec"].lower() != "gtfs":
            continue
        if feed.get("urls") and "static_current" in feed["urls"]:
            gtfs_feeds_urls.append(feed["urls"]["static_current"])
    return gtfs_feeds_urls
def parse_transitland_page(url):
    """Parses a page with the feeds list, extracts feed urls and the next page url."""
    retries = MAX_RETRIES
    while retries > 0:
        try:
            response = requests.get(url)
            response.raise_for_status()
            data = json.loads(response.text)
            gtfs_feeds_urls = get_feeds_links(data["feeds"]) if "feeds" in data else []
            next_page = data.get("meta", {}).get("next", "")
            return gtfs_feeds_urls, next_page
        except requests.exceptions.HTTPError as http_err:
            logger.error(f"HTTP error {http_err} downloading feeds page from {url}")
            # Back off when Transitland rate-limits the crawler.
            if http_err.response is not None and http_err.response.status_code == 429:
                time.sleep(MAX_SLEEP_TIMEOUT_S)
        except requests.exceptions.RequestException as ex:
            logger.error(f"Exception {ex} while parsing Transitland url {url}")
        retries -= 1
    return [], ""
def extract_to_path(content, out_path):
"""Reads content as zip and extracts it to out_path."""
try:
archive = zipfile.ZipFile(io.BytesIO(content))
archive.extractall(path=out_path)
return True
    except zipfile.BadZipFile:
        logger.exception("BadZipFile exception.")
except Exception as e:
logger.exception(f"Exception while unzipping feed: {e}")
return False
def load_gtfs_feed_zip(path, url):
"""Downloads url-located zip and extracts it to path/index."""
retries = MAX_RETRIES
while retries > 0:
try:
response = requests.get(url, stream=True)
response.raise_for_status()
if not extract_to_path(response.content, path):
retries -= 1
logger.error(f"Could not extract zip: {url}")
continue
return True
except requests.exceptions.HTTPError as http_err:
logger.error(f"HTTP error {http_err} downloading zip from {url}")
except requests.exceptions.RequestException as ex:
logger.error(f"Exception {ex} downloading zip from {url}")
retries -= 1
return False
def write_list_to_file(path, lines):
"""Saves list of lines to path."""
with open(path, "w") as out:
out.write("\n".join(lines))
def crawl_transitland_for_feed_urls(out_path, transitland_api_key):
    """Crawls the Transitland feeds API and parses feed urls from the json on each page.
    Do not parallelize this: Transitland rate-limits HTTP requests."""
start_page = "https://transit.land/api/v2/rest/feeds?api_key={}".format(transitland_api_key)
total_feeds = []
gtfs_feeds_urls, next_page = parse_transitland_page(start_page)
while next_page:
logger.info(f"Loaded {next_page}")
total_feeds += gtfs_feeds_urls
gtfs_feeds_urls, next_page = parse_transitland_page(next_page)
if gtfs_feeds_urls:
total_feeds += gtfs_feeds_urls
write_list_to_file(os.path.join(out_path, URLS_FILE_TRANSITLAND), total_feeds)
def get_filename(file_prefix, index):
return f"{file_prefix}_{index:0{MAX_INDEX_LEN}d}"
def load_gtfs_zips_from_urls(path, urls_file, threads_count, file_prefix):
"""Concurrently downloads feeds zips from urls to path."""
    with open(os.path.join(path, urls_file)) as urls_file_handle:
        urls = [url.strip() for url in urls_file_handle]
if not urls:
logger.error(f"Empty urls from {path}")
return
logger.info(f"Preparing to load feeds: {len(urls)}")
err_count = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=threads_count) as executor:
future_to_url = {
executor.submit(
load_gtfs_feed_zip,
os.path.join(path, get_filename(file_prefix, i)),
url,
): url
for i, url in enumerate(urls)
}
for j, future in enumerate(
concurrent.futures.as_completed(future_to_url), start=1
):
url = future_to_url[future]
loaded = future.result()
if not loaded:
err_count += 1
logger.info(f"Handled {j}/{len(urls)} feed. Loaded = {loaded}. {url}")
logger.info(f"Done loading. {err_count}/{len(urls)} errors")
def main():
    """Downloads feed urls from the feed aggregators and saves them to a file,
    then downloads the feeds from these urls into the working directory."""
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("-p", "--path", required=True, help="working directory path")
parser.add_argument(
"-m", "--mode", required=True, help="fullrun | load_feed_urls | load_feed_zips"
)
parser.add_argument(
"-s",
"--source",
default="transitland",
help="source of feeds: transitland | mobilitydb | all",
)
parser.add_argument(
"-t",
"--threads",
type=int,
default=THREADS_COUNT,
help="threads count for loading zips",
)
    # Required in order to use the Transitland API.
parser.add_argument(
"-T",
"--transitland_api_key",
help="user key for working with transitland API v2"
)
# Example: to download data only for Germany and France use "--mdb_countries DE,FR"
parser.add_argument(
"-c",
"--mdb_countries",
help="use data from MobilityDatabase only from selected countries (use ISO codes)",
)
args = parser.parse_args()
logging.basicConfig(
filename=os.path.join(args.path, "crawling.log"),
filemode="w",
level=logging.INFO,
)
if args.mode in ["fullrun", "load_feed_urls"]:
if args.source in ["all", "mobilitydb"]:
mdb_countries = []
if args.mdb_countries:
mdb_countries = args.mdb_countries.split(',')
get_gtfs_urls_mobilitydb(args.path, mdb_countries)
if args.source in ["all", "transitland"]:
if not args.transitland_api_key:
                logger.error(
                    "No API key provided for Transitland. Set the transitland_api_key argument."
                )
return
crawl_transitland_for_feed_urls(args.path, args.transitland_api_key)
if args.mode in ["fullrun", "load_feed_zips"]:
if args.source in ["all", "transitland"]:
load_gtfs_zips_from_urls(
args.path, URLS_FILE_TRANSITLAND, args.threads, "tl"
)
if args.source in ["all", "mobilitydb"]:
load_gtfs_zips_from_urls(args.path, URLS_FILE_MOBILITYDB, args.threads, "mdb")
if __name__ == "__main__":
main()
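For reference, a programmatic sketch of the same pipeline (the working directory and the API key are placeholders; requires the requests package):

    import logging
    import os

    logging.basicConfig(level=logging.INFO)
    out_dir = "/tmp/gtfs"  # placeholder working directory
    os.makedirs(out_dir, exist_ok=True)
    crawl_transitland_for_feed_urls(out_dir, "YOUR_TRANSITLAND_API_KEY")
    load_gtfs_zips_from_urls(out_dir, URLS_FILE_TRANSITLAND, THREADS_COUNT, "tl")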


@@ -0,0 +1 @@
numpy


@@ -0,0 +1,128 @@
import math
def to_rgb(color_str):
if len(color_str) != 6:
return (0, 0, 0)
r = int(color_str[0:2], 16)
g = int(color_str[2:4], 16)
b = int(color_str[4:], 16)
return (r, g, b)
def blend_colors(rgb_array1, rgb_array2, k):
return (rgb_array1[0] * (1.0 - k) + rgb_array2[0] * k,
rgb_array1[1] * (1.0 - k) + rgb_array2[1] * k,
rgb_array1[2] * (1.0 - k) + rgb_array2[2] * k)
def rgb_pivot(n):
    # sRGB -> linear-light conversion, scaled to [0, 100].
    result = n / 12.92
    if n > 0.04045:
        result = ((n + 0.055) / 1.055) ** 2.4
    return result * 100.0
def to_xyz(rgb_array):
    r = rgb_pivot(rgb_array[0] / 255.0)
    g = rgb_pivot(rgb_array[1] / 255.0)
    b = rgb_pivot(rgb_array[2] / 255.0)
    return (r * 0.4124 + g * 0.3576 + b * 0.1805,
            r * 0.2126 + g * 0.7152 + b * 0.0722,
            r * 0.0193 + g * 0.1192 + b * 0.9505)
#https://en.wikipedia.org/wiki/Lab_color_space#CIELAB
def lab_pivot(n):
if n > 0.008856:
return n ** (1.0/3.0)
return (903.3 * n + 16.0) / 116.0
def to_lab(rgb_array):
xyz = to_xyz(rgb_array)
x = lab_pivot(xyz[0] / 95.047)
y = lab_pivot(xyz[1] / 100.0)
z = lab_pivot(xyz[2] / 108.883)
l = 116.0 * y - 16.0
if l < 0.0:
l = 0.0
a = 500.0 * (x - y)
b = 200.0 * (y - z)
return (l, a, b)
def lum_distance(ref_color, src_color):
    # Weighted squared RGB distance; the 30/59/11 weights approximate the
    # eye's sensitivity to the red, green and blue channels.
    return 30 * (ref_color[0] - src_color[0]) ** 2 +\
           59 * (ref_color[1] - src_color[1]) ** 2 +\
           11 * (ref_color[2] - src_color[2]) ** 2
def is_bluish(rgb_array):
d1 = lum_distance((255, 0, 0), rgb_array)
d2 = lum_distance((0, 0, 255), rgb_array)
return d2 < d1
#http://en.wikipedia.org/wiki/Color_difference#CIE94
def cie94(ref_color, src_color):
lab_ref = to_lab(ref_color)
lab_src = to_lab(src_color)
deltaL = lab_ref[0] - lab_src[0]
deltaA = lab_ref[1] - lab_src[1]
deltaB = lab_ref[2] - lab_src[2]
    # Chroma is computed from the a and b components, per the CIE94 definition.
    c1 = math.sqrt(lab_ref[1] * lab_ref[1] + lab_ref[2] * lab_ref[2])
    c2 = math.sqrt(lab_src[1] * lab_src[1] + lab_src[2] * lab_src[2])
    deltaC = c1 - c2
deltaH = deltaA * deltaA + deltaB * deltaB - deltaC * deltaC
if deltaH < 0.0:
deltaH = 0.0
else:
deltaH = math.sqrt(deltaH)
    # CIE94 weighting constants (graphic arts values).
    Kl = 1.0
    K1 = 0.045
    K2 = 0.015
sc = 1.0 + K1 * c1
sh = 1.0 + K2 * c1
deltaLKlsl = deltaL / Kl
deltaCkcsc = deltaC / sc
deltaHkhsh = deltaH / sh
i = deltaLKlsl * deltaLKlsl + deltaCkcsc * deltaCkcsc + deltaHkhsh * deltaHkhsh
if i < 0:
return 0.0
return math.sqrt(i)
class Palette:
def __init__(self, colors):
self.colors = {}
for name, color_info in colors['colors'].items():
self.colors[name] = to_rgb(color_info['clear'])
def get_default_color(self):
return 'default'
def get_nearest_color(self, color_str, casing_color_str, excluded_names):
"""Returns the nearest color from the palette."""
nearest_color_name = self.get_default_color()
color = to_rgb(color_str)
        if casing_color_str:
            color = blend_colors(color, to_rgb(casing_color_str), 0.5)
min_diff = None
bluish = is_bluish(color)
        for name, palette_color in self.colors.items():
            # Uncomment if you want to exclude duplicates.
            # if name in excluded_names:
            #     continue
            if bluish:
                # The weighted RGB distance matches cold (bluish) tones better.
                diff = lum_distance(palette_color, color)
            else:
                diff = cie94(palette_color, color)
if min_diff is None or diff < min_diff:
min_diff = diff
nearest_color_name = name
# Left here for debug purposes.
#print("Result: " + color_str + "," + str(casing_color_str) +
# " - " + nearest_color_name + ": bluish = " + str(bluish))
return nearest_color_name
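A usage sketch with a toy palette (the dict mirrors the JSON shape Palette.__init__ expects; the color values are arbitrary):

    toy_colors = {'colors': {'red': {'clear': 'ff0000'},
                             'blue': {'clear': '0000ff'},
                             'default': {'clear': '999999'}}}
    palette = Palette(toy_colors)
    print(palette.get_nearest_color('e03030', None, []))  # expected: 'red'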


@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# Exports all transit colors to the colors.txt file.
import argparse
import json
import os.path
if __name__ == '__main__':
parser = argparse.ArgumentParser()
default_colors_path = os.path.dirname(os.path.abspath(__file__)) + '/../../../data/colors.txt'
parser.add_argument('in_out_file', nargs='?', type=str, default=default_colors_path,
help='path to colors.txt file')
default_transits_colors_path = os.path.dirname(os.path.abspath(__file__)) + '/../../../data/transit_colors.txt'
parser.add_argument('-c', '--colors', nargs='?', type=str, default=default_transits_colors_path,
help='path to transit_colors.txt file')
args = parser.parse_args()
colors = set()
    with open(args.in_out_file, 'r') as in_file:
        for line in in_file:
            colors.add(int(line))
fields = ['clear', 'night', 'text', 'text_night']
with open(args.colors, 'r') as colors_file:
tr_colors = json.load(colors_file)
for name, color_info in tr_colors['colors'].items():
for field in fields:
if field in color_info:
colors.add(int(color_info[field], 16))
    with open(args.in_out_file, 'w') as out_file:
        for c in sorted(colors):
            # Text mode already translates newlines, so write '\n' explicitly;
            # os.linesep would come out as '\r\r\n' on Windows.
            out_file.write(str(c) + '\n')
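A worked example of the conversion above: a transit color given as hex is stored in colors.txt as its decimal value.

    assert int('ff0000', 16) == 16711680  # written to colors.txt as "16711680"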


@@ -0,0 +1,429 @@
#!/usr/bin/env python3
# Generates a transit graph for the MWM transit section generator.
# Can also show a preview of the transit scheme lines.
import argparse
import copy
import json
import math
import numpy as np
import os.path
import bezier_curves
import transit_color_palette
class OsmIdCode:
NODE = 0x4000000000000000
WAY = 0x8000000000000000
RELATION = 0xC000000000000000
RESET = ~(NODE | WAY | RELATION)
TYPE2CODE = {
'n': NODE,
'r': RELATION,
'w': WAY
}
def get_extended_osm_id(osm_id, osm_type):
try:
return str(osm_id | OsmIdCode.TYPE2CODE[osm_type[0]])
except KeyError:
raise ValueError('Unknown OSM type: ' + osm_type)
def get_line_id(road_id, line_index):
return road_id << 4 | line_index
def get_interchange_node_id(min_stop_id):
return 1 << 62 | min_stop_id
def clamp(value, min_value, max_value):
return max(min(value, max_value), min_value)
def get_mercator_point(lat, lon):
    """Projects lat/lon (degrees) to spherical-Mercator coordinates (degrees)."""
    lat = clamp(lat, -86.0, 86.0)
    sin_x = math.sin(math.radians(lat))
    y = math.degrees(0.5 * math.log((1.0 + sin_x) / (1.0 - sin_x)))
    y = clamp(y, -180, 180)
    return {'x': lon, 'y': y}
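# Illustrative check of the projection above: for lat = 45.0, sin(lat) ≈ 0.7071,
# so y = degrees(0.5 * ln(1.7071 / 0.2929)) ≈ 50.50, hence
# get_mercator_point(45.0, 10.0) ≈ {'x': 10.0, 'y': 50.50}.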
class TransitGraphBuilder:
def __init__(self, input_data, transit_colors, points_per_curve=100, alpha=0.5):
self.palette = transit_color_palette.Palette(transit_colors)
self.input_data = input_data
self.points_per_curve = points_per_curve
self.alpha = alpha
self.networks = []
self.lines = []
self.stops = {}
self.interchange_nodes = set()
self.transfers = {}
self.gates = {}
self.edges = []
self.segments = {}
self.shapes = []
self.transit_graph = None
self.matched_colors = {}
self.stop_names = {}
def __get_average_stops_point(self, stop_ids):
"""Returns an average position of the stops."""
count = len(stop_ids)
if count == 0:
            raise ValueError('Average stops point calculation failed: the list of stop ids is empty.')
average_point = [0, 0]
for stop_id in stop_ids:
point = self.__get_stop(stop_id)['point']
average_point[0] += point['x']
average_point[1] += point['y']
return [average_point[0] / count, average_point[1] / count]
def __add_gate(self, osm_id, is_entrance, is_exit, point, weight, stop_id):
"""Creates a new gate or adds information to the existing with the same weight."""
if (osm_id, weight) in self.gates:
gate_ref = self.gates[(osm_id, weight)]
if stop_id not in gate_ref['stop_ids']:
gate_ref['stop_ids'].append(stop_id)
gate_ref['entrance'] |= is_entrance
gate_ref['exit'] |= is_exit
return
gate = {'osm_id': osm_id,
'point': point,
'weight': weight,
'stop_ids': [stop_id],
'entrance': is_entrance,
'exit': is_exit
}
self.gates[(osm_id, weight)] = gate
def __get_interchange_node(self, stop_id):
"""Returns the existing interchange node or creates a new one."""
for node_stops in self.interchange_nodes:
if stop_id in node_stops:
return node_stops
return (stop_id,)
def __get_stop(self, stop_id):
"""Returns the stop or the interchange node."""
if stop_id in self.stops:
return self.stops[stop_id]
return self.transfers[stop_id]
    def __check_line_title(self, line, route_name):
        """Builds a line title if one was not provided."""
if line['title']:
return
name = route_name if route_name else line['number']
if len(line['stop_ids']) > 1:
first_stop = self.stop_names[line['stop_ids'][0]]
last_stop = self.stop_names[line['stop_ids'][-1]]
if first_stop and last_stop:
line['title'] = u'{0}: {1} - {2}'.format(name, first_stop, last_stop)
return
line['title'] = name
def __read_stops(self):
"""Reads stops, their exits and entrances."""
for stop_item in self.input_data['stops']:
stop = {}
stop['id'] = stop_item['id']
stop['osm_id'] = get_extended_osm_id(stop_item['osm_id'], stop_item['osm_type'])
if 'zone_id' in stop_item:
stop['zone_id'] = stop_item['zone_id']
stop['point'] = get_mercator_point(stop_item['lat'], stop_item['lon'])
stop['line_ids'] = []
# TODO: Save stop names stop_item['name'] and stop_item['int_name'] for text anchors calculation.
stop['title_anchors'] = []
self.stops[stop['id']] = stop
self.stop_names[stop['id']] = stop_item['name']
for entrance_item in stop_item['entrances']:
ex_id = get_extended_osm_id(entrance_item['osm_id'], entrance_item['osm_type'])
point = get_mercator_point(entrance_item['lat'], entrance_item['lon'])
self.__add_gate(ex_id, True, False, point, entrance_item['distance'], stop['id'])
for exit_item in stop_item['exits']:
ex_id = get_extended_osm_id(exit_item['osm_id'], exit_item['osm_type'])
point = get_mercator_point(exit_item['lat'], exit_item['lon'])
self.__add_gate(ex_id, False, True, point, exit_item['distance'], stop['id'])
def __read_transfers(self):
"""Reads transfers between stops."""
for transfer_item in self.input_data['transfers']:
edge = {'stop1_id': transfer_item[0],
'stop2_id': transfer_item[1],
'weight': transfer_item[2],
'transfer': True
}
self.edges.append(copy.deepcopy(edge))
edge['stop1_id'], edge['stop2_id'] = edge['stop2_id'], edge['stop1_id']
self.edges.append(edge)
def __read_networks(self):
"""Reads networks and routes."""
for network_item in self.input_data['networks']:
network_id = network_item['agency_id']
network = {'id': network_id,
'title': network_item['network']}
self.networks.append(network)
for route_item in network_item['routes']:
line_index = 0
# Create a line for each itinerary.
for line_item in route_item['itineraries']:
line_stops = line_item['stops']
line_id = get_line_id(route_item['route_id'], line_index)
line = {'id': line_id,
'title': line_item.get('name', ''),
'type': route_item['type'],
'network_id': network_id,
'number': route_item['ref'],
'interval': line_item['interval'],
'stop_ids': []
}
line['color'] = self.__match_color(route_item.get('colour', ''), route_item.get('casing', ''))
                    # TODO: Process line_item['shape'] when this data becomes available.
                    # TODO: Process line_item['trip_ids'] when this data becomes available.
# Create an edge for each connection of stops.
for i in range(len(line_stops)):
stop1 = line_stops[i]
line['stop_ids'].append(stop1[0])
self.stops[stop1[0]]['line_ids'].append(line_id)
if i + 1 < len(line_stops):
stop2 = line_stops[i + 1]
edge = {'stop1_id': stop1[0],
'stop2_id': stop2[0],
'weight': stop2[1] - stop1[1],
'transfer': False,
'line_id': line_id,
'shape_ids': []
}
self.edges.append(edge)
self.__check_line_title(line, route_item.get('name', ''))
self.lines.append(line)
line_index += 1
def __match_color(self, color_str, casing_str):
if color_str is None or len(color_str) == 0:
return self.palette.get_default_color()
if casing_str is None:
casing_str = ''
matched_colors_key = color_str + "/" + casing_str
if matched_colors_key in self.matched_colors:
return self.matched_colors[matched_colors_key]
c = self.palette.get_nearest_color(color_str, casing_str, self.matched_colors.values())
if c != self.palette.get_default_color():
self.matched_colors[matched_colors_key] = c
return c
def __generate_transfer_nodes(self):
"""Merges stops into transfer nodes."""
for edge in self.edges:
if edge['transfer']:
node1 = self.__get_interchange_node(edge['stop1_id'])
node2 = self.__get_interchange_node(edge['stop2_id'])
merged_node = tuple(sorted(set(node1 + node2)))
self.interchange_nodes.discard(node1)
self.interchange_nodes.discard(node2)
self.interchange_nodes.add(merged_node)
for node_stop_ids in self.interchange_nodes:
point = self.__get_average_stops_point(node_stop_ids)
transfer = {'id': get_interchange_node_id(self.stops[node_stop_ids[0]]['id']),
'stop_ids': list(node_stop_ids),
'point': {'x': point[0], 'y': point[1]},
'title_anchors': []
}
for stop_id in node_stop_ids:
self.stops[stop_id]['transfer_id'] = transfer['id']
self.transfers[transfer['id']] = transfer
def __collect_segments(self):
"""Prepares collection of segments for shapes generation."""
        # Each line is divided into segments by its stops and transfer nodes.
        # Equal segments from different lines are merged into one, and the stops
        # adjacent to each segment are collected: their average positions serve
        # as guide points for the curve generation.
for line in self.lines:
prev_seg = None
prev_id1 = None
for i in range(len(line['stop_ids']) - 1):
node1 = self.stops[line['stop_ids'][i]]
node2 = self.stops[line['stop_ids'][i + 1]]
id1 = node1.get('transfer_id', node1['id'])
id2 = node2.get('transfer_id', node2['id'])
if id1 == id2:
continue
seg = tuple(sorted([id1, id2]))
if seg not in self.segments:
self.segments[seg] = {'guide_points': {id1: set(), id2: set()}}
if prev_seg is not None:
self.segments[seg]['guide_points'][id1].add(prev_id1)
self.segments[prev_seg]['guide_points'][id1].add(id2)
prev_seg = seg
prev_id1 = id1
def __generate_shapes_for_segments(self):
"""Generates a curve for each connection of two stops / transfer nodes."""
for (id1, id2), info in self.segments.items():
point1 = [self.__get_stop(id1)['point']['x'], self.__get_stop(id1)['point']['y']]
point2 = [self.__get_stop(id2)['point']['x'], self.__get_stop(id2)['point']['y']]
if info['guide_points'][id1]:
guide1 = self.__get_average_stops_point(info['guide_points'][id1])
else:
guide1 = [2 * point1[0] - point2[0], 2 * point1[1] - point2[1]]
if info['guide_points'][id2]:
guide2 = self.__get_average_stops_point(info['guide_points'][id2])
else:
guide2 = [2 * point2[0] - point1[0], 2 * point2[1] - point1[1]]
curve_points = bezier_curves.segment_to_Catmull_Rom_curve(guide1, point1, point2, guide2,
self.points_per_curve, self.alpha)
info['curve'] = np.array(curve_points)
polyline = []
for point in curve_points:
polyline.append({'x': point[0], 'y': point[1]})
shape = {'id': {'stop1_id': id1, 'stop2_id': id2},
'polyline': polyline}
self.shapes.append(shape)
def __assign_shapes_to_edges(self):
"""Assigns a shape to each non-transfer edge."""
for edge in self.edges:
if not edge['transfer']:
stop1 = self.stops[edge['stop1_id']]
stop2 = self.stops[edge['stop2_id']]
id1 = stop1.get('transfer_id', stop1['id'])
id2 = stop2.get('transfer_id', stop2['id'])
seg = tuple(sorted([id1, id2]))
if seg in self.segments:
edge['shape_ids'].append({'stop1_id': seg[0], 'stop2_id': seg[1]})
def __create_scheme_shapes(self):
self.__collect_segments()
self.__generate_shapes_for_segments()
self.__assign_shapes_to_edges()
def build(self):
if self.transit_graph is not None:
return self.transit_graph
self.__read_stops()
self.__read_transfers()
self.__read_networks()
self.__generate_transfer_nodes()
self.__create_scheme_shapes()
self.transit_graph = {'networks': self.networks,
'lines': self.lines,
'gates': list(self.gates.values()),
'stops': list(self.stops.values()),
'transfers': list(self.transfers.values()),
'shapes': self.shapes,
'edges': self.edges}
return self.transit_graph
def show_preview(self):
import matplotlib.pyplot as plt
for (s1, s2), info in self.segments.items():
plt.plot(info['curve'][:, 0], info['curve'][:, 1], 'g')
for stop in self.stops.values():
if 'transfer_id' in stop:
point = self.transfers[stop['transfer_id']]['point']
size = 60
color = 'r'
else:
point = stop['point']
if len(stop['line_ids']) > 2:
size = 40
color = 'b'
else:
size = 20
color = 'g'
plt.scatter([point['x']], [point['y']], size, color)
plt.show()
    def show_color_matching_table(self, title, colors_ref_table):
import matplotlib.pyplot as plt
import matplotlib.patches as patches
fig = plt.figure()
ax = fig.add_subplot(111, aspect='equal')
plt.title(title)
sz = 1.0 / (2.0 * len(self.matched_colors))
delta_y = sz * 0.5
for c in self.matched_colors:
tokens = c.split('/')
if len(tokens[1]) == 0:
tokens[1] = tokens[0]
ax.add_patch(patches.Rectangle((sz, delta_y), sz, sz, facecolor="#" + tokens[0], edgecolor="#" + tokens[1]))
rect_title = tokens[0]
if tokens[0] != tokens[1]:
rect_title += "/" + tokens[1]
ax.text(2.5 * sz, delta_y, rect_title + " -> ")
ref_color = colors_ref_table[self.matched_colors[c]]
ax.add_patch(patches.Rectangle((0.3 + sz, delta_y), sz, sz, facecolor="#" + ref_color))
ax.text(0.3 + 2.5 * sz, delta_y, ref_color + " (" + self.matched_colors[c] + ")")
delta_y += sz * 2.0
plt.show()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('input_file', help='input file name of transit data')
parser.add_argument('output_file', nargs='?', help='output file name of generated graph')
default_colors_path = os.path.dirname(os.path.abspath(__file__)) + '/../../../data/transit_colors.txt'
parser.add_argument('-c', '--colors', type=str, default=default_colors_path,
help='transit colors file COLORS_FILE_PATH', metavar='COLORS_FILE_PATH')
parser.add_argument('-p', '--preview', action="store_true", default=False,
help="show preview of the transit scheme")
parser.add_argument('-m', '--matched_colors', action="store_true", default=False,
help="show the matched colors table")
parser.add_argument('-a', '--alpha', type=float, default=0.5, help='the curves generator parameter value ALPHA',
metavar='ALPHA')
parser.add_argument('-n', '--num', type=int, default=100, help='the number NUM of points in a generated curve',
metavar='NUM')
args = parser.parse_args()
with open(args.input_file, 'r') as input_file:
data = json.load(input_file)
with open(args.colors, 'r') as colors_file:
colors = json.load(colors_file)
transit = TransitGraphBuilder(data, colors, args.num, args.alpha)
result = transit.build()
output_file = args.output_file
head, tail = os.path.split(os.path.abspath(args.input_file))
name, extension = os.path.splitext(tail)
if output_file is None:
output_file = os.path.join(head, name + '.transit' + extension)
with open(output_file, 'w') as json_file:
result_data = json.dumps(result, ensure_ascii=False, indent='\t', sort_keys=True, separators=(',', ':'))
json_file.write(result_data)
print('Transit graph generated:', output_file)
if args.preview:
transit.show_preview()
if args.matched_colors:
colors_ref_table = {}
for color_name, color_info in colors['colors'].items():
colors_ref_table[color_name] = color_info['clear']
        transit.show_color_matching_table(name, colors_ref_table)
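An end-to-end sketch of driving the builder from code instead of the CLI (the file names are placeholders; the input JSON must follow the stops / transfers / networks layout read above):

    import json

    with open('transit_input.json') as f:      # placeholder input file
        data = json.load(f)
    with open('transit_colors.txt') as f:      # palette file in the format above
        colors = json.load(f)
    builder = TransitGraphBuilder(data, colors, points_per_curve=50, alpha=0.5)
    graph = builder.build()
    print(len(graph['stops']), 'stops,', len(graph['shapes']), 'shapes')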