Repo created
This commit is contained in:
parent
4af19165ec
commit
68073add76
12458 changed files with 12350765 additions and 2 deletions
89
tools/python/transit/bezier_curves.py
Normal file
89
tools/python/transit/bezier_curves.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Copyright (c) 2015, Emilia Petrisor
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def knot_interval(i_pts, alpha=0.5, closed=False):
|
||||
if len(i_pts) < 4:
|
||||
raise ValueError('CR-curves need at least 4 interpolatory points')
|
||||
# i_pts is the list of interpolatory points P[0], P[1], ... P[n]
|
||||
if closed:
|
||||
i_pts += [i_pts[0], i_pts[1], i_pts[2]]
|
||||
i_pts = np.array(i_pts)
|
||||
dist = np.linalg.norm(i_pts[1:, :] - i_pts[:-1, :], axis=1)
|
||||
return dist ** alpha
|
||||
|
||||
|
||||
def ctrl_bezier(P, d):
|
||||
# Associate to 4 consecutive interpolatory points and the corresponding three d-values,
|
||||
# the Bezier control points
|
||||
if len(P) != len(d) + 1 != 4:
|
||||
raise ValueError('The list of points and knot intervals have inappropriate len ')
|
||||
P = np.array(P)
|
||||
bz = [0] * 4
|
||||
bz[0] = P[1]
|
||||
bz[1] = (d[0] ** 2 * P[2] - d[1] ** 2 * P[0] + (2 * d[0] ** 2 + 3 * d[0] * d[1] + d[1] ** 2) * P[1]) / (
|
||||
3 * d[0] * (d[0] + d[1]))
|
||||
bz[2] = (d[2] ** 2 * P[1] - d[1] ** 2 * P[3] + (2 * d[2] ** 2 + 3 * d[2] * d[1] + d[1] ** 2) * P[2]) / (
|
||||
3 * d[2] * (d[1] + d[2]))
|
||||
bz[3] = P[2]
|
||||
return bz
|
||||
|
||||
|
||||
def Bezier_curve(bz, nr=100):
|
||||
# implements the de Casteljau algorithm to compute nr points on a Bezier curve
|
||||
t = np.linspace(0, 1, nr)
|
||||
N = len(bz)
|
||||
points = [] # the list of points to be computed on the Bezier curve
|
||||
for i in range(nr): # for each parameter t[i] evaluate a point on the Bezier curve
|
||||
# via De Casteljau algorithm
|
||||
aa = np.copy(bz)
|
||||
for r in range(1, N):
|
||||
aa[:N - r, :] = (1 - t[i]) * aa[:N - r, :] + t[i] * aa[1:N - r + 1, :] # convex combination
|
||||
points.append(aa[0, :])
|
||||
return points
|
||||
|
||||
|
||||
def Catmull_Rom(i_pts, alpha=0.5, closed=False):
|
||||
# returns the list of points computed on the interpolating CR curve
|
||||
# i_pts the list of interpolatory points P[0], P[1], ...P[n]
|
||||
curve_pts = [] # the list of all points to be computed on the CR curve
|
||||
d = knot_interval(i_pts, alpha=alpha, closed=closed)
|
||||
for k in range(len(i_pts) - 3):
|
||||
cb = ctrl_bezier(i_pts[k:k + 4], d[k:k + 3])
|
||||
curve_pts.extend(Bezier_curve(cb, nr=100))
|
||||
|
||||
return np.array(curve_pts)
|
||||
|
||||
|
||||
def segment_to_Catmull_Rom_curve(p1, s1, s2, p2, nr=100, alpha=0.5):
|
||||
i_pts = [p1, s1, s2, p2]
|
||||
# returns the list of points computed on the interpolating CR curve
|
||||
# i_pts the list of interpolatory points P[0], P[1], ...P[n]
|
||||
curve_pts = [] # the list of all points to be computed on the CR curve
|
||||
d = knot_interval(i_pts, alpha=alpha, closed=False)
|
||||
cb = ctrl_bezier(i_pts, d)
|
||||
curve_pts.extend(Bezier_curve(cb, nr=nr))
|
||||
return curve_pts
|
||||
288
tools/python/transit/gtfs/download_gtfs.py
Normal file
288
tools/python/transit/gtfs/download_gtfs.py
Normal file
|
|
@ -0,0 +1,288 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Parses GTFS feeds urls:
|
||||
https://transit.land/ - Transitland
|
||||
https://storage.googleapis.com/storage/v1/b/mdb-csv/o/sources.csv?alt=media
|
||||
- Mobility Database (https://mobilitydata.org/)
|
||||
Crawls all the urls, loads feed zips and extracts to the specified directory."""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import csv
|
||||
import time
|
||||
import zipfile
|
||||
|
||||
import requests
|
||||
|
||||
MAX_RETRIES = 2
|
||||
MAX_SLEEP_TIMEOUT_S = 30
|
||||
|
||||
RAW_FILE_MOBILITYDB = "raw_mobilitydb.csv"
|
||||
|
||||
URLS_FILE_TRANSITLAND = "feed_urls_transitland.txt"
|
||||
URLS_FILE_MOBILITYDB = "feed_urls_mobilitydb.txt"
|
||||
|
||||
URL_MOBILITYDB_GTFS_SOURCE = "https://storage.googleapis.com/storage/v1/b/mdb-csv/o/sources.csv?alt=media"
|
||||
|
||||
THREADS_COUNT = 2
|
||||
MAX_INDEX_LEN = 4
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def download_gtfs_sources_mobilitydb(path):
|
||||
"""Downloads the csv catalogue from Data Mobility"""
|
||||
try:
|
||||
req = requests.get(URL_MOBILITYDB_GTFS_SOURCE)
|
||||
url_content = req.content
|
||||
with open(os.path.join(path, RAW_FILE_MOBILITYDB), 'wb') as csv_file:
|
||||
csv_file.write(url_content)
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
logger.error(
|
||||
f"HTTP error {http_err} downloading zip from {URL_MOBILITYDB_GTFS_SOURCE}")
|
||||
|
||||
|
||||
def get_gtfs_urls_mobilitydb(path, countries_list):
|
||||
"""Extracts the feed urls from the downloaded csv file"""
|
||||
download_from_all_countries = True
|
||||
if countries_list:
|
||||
download_from_all_countries = False
|
||||
|
||||
download_gtfs_sources_mobilitydb(path)
|
||||
file = open(os.path.join(path, RAW_FILE_MOBILITYDB), encoding='UTF-8')
|
||||
raw_sources = csv.DictReader(file)
|
||||
next(raw_sources)
|
||||
urls = [field["urls.direct_download"] for field in raw_sources if download_from_all_countries or field["location.country_code"] in countries_list]
|
||||
write_list_to_file(os.path.join(path, URLS_FILE_MOBILITYDB), urls)
|
||||
|
||||
|
||||
def get_feeds_links(data):
|
||||
"""Extracts feed urls from the GTFS json description."""
|
||||
gtfs_feeds_urls = []
|
||||
|
||||
for feed in data:
|
||||
# Possible values: MDS, GBFS, GTFS_RT, GRFS
|
||||
if feed["spec"].lower() != "gtfs":
|
||||
continue
|
||||
|
||||
if "urls" in feed and feed["urls"] is not None and feed["urls"]:
|
||||
gtfs_feeds_urls.append(feed["urls"]["static_current"])
|
||||
|
||||
return gtfs_feeds_urls
|
||||
|
||||
|
||||
def parse_transitland_page(url):
|
||||
"""Parses page with feeds list, extracts feeds urls and the next page url."""
|
||||
retries = MAX_RETRIES
|
||||
|
||||
while retries > 0:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
data = json.loads(response.text)
|
||||
if "feeds" in data:
|
||||
gtfs_feeds_urls = get_feeds_links(data["feeds"])
|
||||
else:
|
||||
gtfs_feeds_urls = []
|
||||
|
||||
next_page = data["meta"]["next"] if "next" in data.get("meta", "") else ""
|
||||
return gtfs_feeds_urls, next_page
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
logger.error(f"HTTP error {http_err} downloading zip from {url}")
|
||||
if http_err == 429:
|
||||
time.sleep(MAX_SLEEP_TIMEOUT_S)
|
||||
except requests.exceptions.RequestException as ex:
|
||||
logger.error(
|
||||
f"Exception {ex} while parsing Transitland url {url} with code {response.status_code}"
|
||||
)
|
||||
|
||||
retries -= 1
|
||||
|
||||
return [], ""
|
||||
|
||||
|
||||
def extract_to_path(content, out_path):
|
||||
"""Reads content as zip and extracts it to out_path."""
|
||||
try:
|
||||
archive = zipfile.ZipFile(io.BytesIO(content))
|
||||
archive.extractall(path=out_path)
|
||||
return True
|
||||
except zipfile.BadZipfile:
|
||||
logger.exception("BadZipfile exception.")
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception while unzipping feed: {e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def load_gtfs_feed_zip(path, url):
|
||||
"""Downloads url-located zip and extracts it to path/index."""
|
||||
retries = MAX_RETRIES
|
||||
while retries > 0:
|
||||
try:
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
if not extract_to_path(response.content, path):
|
||||
retries -= 1
|
||||
logger.error(f"Could not extract zip: {url}")
|
||||
continue
|
||||
|
||||
return True
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
logger.error(f"HTTP error {http_err} downloading zip from {url}")
|
||||
except requests.exceptions.RequestException as ex:
|
||||
logger.error(f"Exception {ex} downloading zip from {url}")
|
||||
retries -= 1
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def write_list_to_file(path, lines):
|
||||
"""Saves list of lines to path."""
|
||||
with open(path, "w") as out:
|
||||
out.write("\n".join(lines))
|
||||
|
||||
|
||||
def crawl_transitland_for_feed_urls(out_path, transitland_api_key):
|
||||
"""Crawls transitland feeds API and parses feeds urls from json on each page
|
||||
Do not try to parallel it because of the Transitland HTTP requests restriction."""
|
||||
start_page = "https://transit.land/api/v2/rest/feeds?api_key={}".format(transitland_api_key)
|
||||
|
||||
total_feeds = []
|
||||
gtfs_feeds_urls, next_page = parse_transitland_page(start_page)
|
||||
|
||||
while next_page:
|
||||
logger.info(f"Loaded {next_page}")
|
||||
total_feeds += gtfs_feeds_urls
|
||||
gtfs_feeds_urls, next_page = parse_transitland_page(next_page)
|
||||
|
||||
if gtfs_feeds_urls:
|
||||
total_feeds += gtfs_feeds_urls
|
||||
|
||||
write_list_to_file(os.path.join(out_path, URLS_FILE_TRANSITLAND), total_feeds)
|
||||
|
||||
|
||||
def get_filename(file_prefix, index):
|
||||
return f"{file_prefix}_{index:0{MAX_INDEX_LEN}d}"
|
||||
|
||||
|
||||
def load_gtfs_zips_from_urls(path, urls_file, threads_count, file_prefix):
|
||||
"""Concurrently downloads feeds zips from urls to path."""
|
||||
urls = [url.strip() for url in open(os.path.join(path, urls_file))]
|
||||
if not urls:
|
||||
logger.error(f"Empty urls from {path}")
|
||||
return
|
||||
logger.info(f"Preparing to load feeds: {len(urls)}")
|
||||
err_count = 0
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=threads_count) as executor:
|
||||
future_to_url = {
|
||||
executor.submit(
|
||||
load_gtfs_feed_zip,
|
||||
os.path.join(path, get_filename(file_prefix, i)),
|
||||
url,
|
||||
): url
|
||||
for i, url in enumerate(urls)
|
||||
}
|
||||
|
||||
for j, future in enumerate(
|
||||
concurrent.futures.as_completed(future_to_url), start=1
|
||||
):
|
||||
url = future_to_url[future]
|
||||
|
||||
loaded = future.result()
|
||||
if not loaded:
|
||||
err_count += 1
|
||||
logger.info(f"Handled {j}/{len(urls)} feed. Loaded = {loaded}. {url}")
|
||||
|
||||
logger.info(f"Done loading. {err_count}/{len(urls)} errors")
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
"""Downloads urls of feeds from feed aggregators and saves to the file.
|
||||
Downloads feeds from these urls and saves to the directory."""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument("-p", "--path", required=True, help="working directory path")
|
||||
|
||||
parser.add_argument(
|
||||
"-m", "--mode", required=True, help="fullrun | load_feed_urls | load_feed_zips"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--source",
|
||||
default="transitland",
|
||||
help="source of feeds: transitland | mobilitydb | all",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--threads",
|
||||
type=int,
|
||||
default=THREADS_COUNT,
|
||||
help="threads count for loading zips",
|
||||
)
|
||||
|
||||
# Required in order to use Transitlands api
|
||||
parser.add_argument(
|
||||
"-T",
|
||||
"--transitland_api_key",
|
||||
help="user key for working with transitland API v2"
|
||||
)
|
||||
|
||||
# Example: to download data only for Germany and France use "--mdb_countries DE,FR"
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--mdb_countries",
|
||||
help="use data from MobilityDatabase only from selected countries (use ISO codes)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
filename=os.path.join(args.path, "crawling.log"),
|
||||
filemode="w",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
if args.mode in ["fullrun", "load_feed_urls"]:
|
||||
|
||||
if args.source in ["all", "mobilitydb"]:
|
||||
mdb_countries = []
|
||||
if args.mdb_countries:
|
||||
mdb_countries = args.mdb_countries.split(',')
|
||||
|
||||
get_gtfs_urls_mobilitydb(args.path, mdb_countries)
|
||||
if args.source in ["all", "transitland"]:
|
||||
if not args.transitland_api_key:
|
||||
logger.error(
|
||||
"No key provided for Transit Land. Set transitland_api_key argument."
|
||||
)
|
||||
return
|
||||
crawl_transitland_for_feed_urls(args.path, args.transitland_api_key)
|
||||
|
||||
if args.mode in ["fullrun", "load_feed_zips"]:
|
||||
|
||||
if args.source in ["all", "transitland"]:
|
||||
load_gtfs_zips_from_urls(
|
||||
args.path, URLS_FILE_TRANSITLAND, args.threads, "tl"
|
||||
)
|
||||
if args.source in ["all", "mobilitydb"]:
|
||||
load_gtfs_zips_from_urls(args.path, URLS_FILE_MOBILITYDB, args.threads, "mdb")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
tools/python/transit/requirements.txt
Normal file
1
tools/python/transit/requirements.txt
Normal file
|
|
@ -0,0 +1 @@
|
|||
numpy
|
||||
128
tools/python/transit/transit_color_palette.py
Normal file
128
tools/python/transit/transit_color_palette.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
import math
|
||||
|
||||
def to_rgb(color_str):
|
||||
if len(color_str) != 6:
|
||||
return (0, 0, 0)
|
||||
r = int(color_str[0:2], 16)
|
||||
g = int(color_str[2:4], 16)
|
||||
b = int(color_str[4:], 16)
|
||||
return (r, g, b)
|
||||
|
||||
|
||||
def blend_colors(rgb_array1, rgb_array2, k):
|
||||
return (rgb_array1[0] * (1.0 - k) + rgb_array2[0] * k,
|
||||
rgb_array1[1] * (1.0 - k) + rgb_array2[1] * k,
|
||||
rgb_array1[2] * (1.0 - k) + rgb_array2[2] * k)
|
||||
|
||||
|
||||
def rgb_pivot(n):
|
||||
result = n / 12.92
|
||||
if n > 0.04045:
|
||||
result = ((n + 0.055) / 1.055) ** 2.4
|
||||
return result * 100.0;
|
||||
|
||||
|
||||
def to_xyz(rgb_array):
|
||||
r = rgb_pivot(rgb_array[0] / 255.0);
|
||||
g = rgb_pivot(rgb_array[1] / 255.0);
|
||||
b = rgb_pivot(rgb_array[2] / 255.0);
|
||||
return (r * 0.4124 + g * 0.3576 + b * 0.1805,
|
||||
r * 0.2126 + g * 0.7152 + b * 0.0722,
|
||||
r * 0.0193 + g * 0.1192 + b * 0.9505)
|
||||
|
||||
|
||||
#https://en.wikipedia.org/wiki/Lab_color_space#CIELAB
|
||||
def lab_pivot(n):
|
||||
if n > 0.008856:
|
||||
return n ** (1.0/3.0)
|
||||
return (903.3 * n + 16.0) / 116.0
|
||||
|
||||
|
||||
def to_lab(rgb_array):
|
||||
xyz = to_xyz(rgb_array)
|
||||
x = lab_pivot(xyz[0] / 95.047)
|
||||
y = lab_pivot(xyz[1] / 100.0)
|
||||
z = lab_pivot(xyz[2] / 108.883)
|
||||
l = 116.0 * y - 16.0
|
||||
if l < 0.0:
|
||||
l = 0.0
|
||||
a = 500.0 * (x - y)
|
||||
b = 200.0 * (y - z)
|
||||
return (l, a, b)
|
||||
|
||||
|
||||
def lum_distance(ref_color, src_color):
|
||||
return 30 * (ref_color[0] - src_color[0]) ** 2 +\
|
||||
59 * (ref_color[1] - src_color[1]) ** 2 +\
|
||||
11 * (ref_color[2] - src_color[2]) ** 2
|
||||
|
||||
|
||||
def is_bluish(rgb_array):
|
||||
d1 = lum_distance((255, 0, 0), rgb_array)
|
||||
d2 = lum_distance((0, 0, 255), rgb_array)
|
||||
return d2 < d1
|
||||
|
||||
|
||||
#http://en.wikipedia.org/wiki/Color_difference#CIE94
|
||||
def cie94(ref_color, src_color):
|
||||
lab_ref = to_lab(ref_color)
|
||||
lab_src = to_lab(src_color)
|
||||
deltaL = lab_ref[0] - lab_src[0]
|
||||
deltaA = lab_ref[1] - lab_src[1]
|
||||
deltaB = lab_ref[2] - lab_src[2]
|
||||
c1 = math.sqrt(lab_ref[0] * lab_ref[0] + lab_ref[1] * lab_ref[1])
|
||||
c2 = math.sqrt(lab_src[0] * lab_src[0] + lab_src[1] * lab_src[1])
|
||||
deltaC = c1 - c2
|
||||
deltaH = deltaA * deltaA + deltaB * deltaB - deltaC * deltaC
|
||||
if deltaH < 0.0:
|
||||
deltaH = 0.0
|
||||
else:
|
||||
deltaH = math.sqrt(deltaH)
|
||||
# cold tones if a color is more bluish.
|
||||
Kl = 1.0
|
||||
K1 = 0.045
|
||||
K2 = 0.015
|
||||
sc = 1.0 + K1 * c1
|
||||
sh = 1.0 + K2 * c1
|
||||
deltaLKlsl = deltaL / Kl
|
||||
deltaCkcsc = deltaC / sc
|
||||
deltaHkhsh = deltaH / sh
|
||||
i = deltaLKlsl * deltaLKlsl + deltaCkcsc * deltaCkcsc + deltaHkhsh * deltaHkhsh
|
||||
if i < 0:
|
||||
return 0.0
|
||||
return math.sqrt(i)
|
||||
|
||||
|
||||
class Palette:
|
||||
def __init__(self, colors):
|
||||
self.colors = {}
|
||||
for name, color_info in colors['colors'].items():
|
||||
self.colors[name] = to_rgb(color_info['clear'])
|
||||
|
||||
def get_default_color(self):
|
||||
return 'default'
|
||||
|
||||
def get_nearest_color(self, color_str, casing_color_str, excluded_names):
|
||||
"""Returns the nearest color from the palette."""
|
||||
nearest_color_name = self.get_default_color()
|
||||
color = to_rgb(color_str)
|
||||
if (casing_color_str is not None and len(casing_color_str) != 0):
|
||||
color = blend_colors(color, to_rgb(casing_color_str), 0.5)
|
||||
min_diff = None
|
||||
|
||||
bluish = is_bluish(color)
|
||||
for name, palette_color in self.colors.items():
|
||||
# Uncomment if you want to exclude duplicates.
|
||||
#if name in excluded_names:
|
||||
# continue
|
||||
if bluish:
|
||||
diff = lum_distance(palette_color, color)
|
||||
else:
|
||||
diff = cie94(palette_color, color)
|
||||
if min_diff is None or diff < min_diff:
|
||||
min_diff = diff
|
||||
nearest_color_name = name
|
||||
# Left here for debug purposes.
|
||||
#print("Result: " + color_str + "," + str(casing_color_str) +
|
||||
# " - " + nearest_color_name + ": bluish = " + str(bluish))
|
||||
return nearest_color_name
|
||||
34
tools/python/transit/transit_colors_export.py
Executable file
34
tools/python/transit/transit_colors_export.py
Executable file
|
|
@ -0,0 +1,34 @@
|
|||
#!/usr/bin/env python3
|
||||
# It exports all transits colors to colors.txt file.
|
||||
import argparse
|
||||
import json
|
||||
import os.path
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
default_colors_path = os.path.dirname(os.path.abspath(__file__)) + '/../../../data/colors.txt'
|
||||
parser.add_argument('in_out_file', nargs='?', type=str, default=default_colors_path,
|
||||
help='path to colors.txt file')
|
||||
default_transits_colors_path = os.path.dirname(os.path.abspath(__file__)) + '/../../../data/transit_colors.txt'
|
||||
parser.add_argument('-c', '--colors', nargs='?', type=str, default=default_transits_colors_path,
|
||||
help='path to transit_colors.txt file')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
colors = set()
|
||||
with open(args.in_out_file, 'r') as in_file:
|
||||
lines = in_file.readlines()
|
||||
for l in lines:
|
||||
colors.add(int(l))
|
||||
|
||||
fields = ['clear', 'night', 'text', 'text_night']
|
||||
with open(args.colors, 'r') as colors_file:
|
||||
tr_colors = json.load(colors_file)
|
||||
for name, color_info in tr_colors['colors'].items():
|
||||
for field in fields:
|
||||
if field in color_info:
|
||||
colors.add(int(color_info[field], 16))
|
||||
|
||||
with open(args.in_out_file, 'w') as out_file:
|
||||
for c in sorted(colors):
|
||||
out_file.write(str(c) + os.linesep)
|
||||
429
tools/python/transit/transit_graph_generator.py
Executable file
429
tools/python/transit/transit_graph_generator.py
Executable file
|
|
@ -0,0 +1,429 @@
|
|||
#!/usr/bin/env python3
|
||||
# Generates transit graph for MWM transit section generator.
|
||||
# Also shows preview of transit scheme lines.
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import numpy as np
|
||||
import os.path
|
||||
|
||||
import bezier_curves
|
||||
import transit_color_palette
|
||||
|
||||
|
||||
class OsmIdCode:
|
||||
NODE = 0x4000000000000000
|
||||
WAY = 0x8000000000000000
|
||||
RELATION = 0xC000000000000000
|
||||
RESET = ~(NODE | WAY | RELATION)
|
||||
|
||||
TYPE2CODE = {
|
||||
'n': NODE,
|
||||
'r': RELATION,
|
||||
'w': WAY
|
||||
}
|
||||
|
||||
|
||||
def get_extended_osm_id(osm_id, osm_type):
|
||||
try:
|
||||
return str(osm_id | OsmIdCode.TYPE2CODE[osm_type[0]])
|
||||
except KeyError:
|
||||
raise ValueError('Unknown OSM type: ' + osm_type)
|
||||
|
||||
|
||||
def get_line_id(road_id, line_index):
|
||||
return road_id << 4 | line_index
|
||||
|
||||
|
||||
def get_interchange_node_id(min_stop_id):
|
||||
return 1 << 62 | min_stop_id
|
||||
|
||||
|
||||
def clamp(value, min_value, max_value):
|
||||
return max(min(value, max_value), min_value)
|
||||
|
||||
|
||||
def get_mercator_point(lat, lon):
|
||||
lat = clamp(lat, -86.0, 86.0)
|
||||
sin_x = math.sin(math.radians(lat))
|
||||
y = math.degrees(0.5 * math.log((1.0 + sin_x) / (1.0 - sin_x)))
|
||||
y = clamp(y, -180, 180)
|
||||
return {'x': lon, 'y': y}
|
||||
|
||||
|
||||
class TransitGraphBuilder:
|
||||
def __init__(self, input_data, transit_colors, points_per_curve=100, alpha=0.5):
|
||||
self.palette = transit_color_palette.Palette(transit_colors)
|
||||
self.input_data = input_data
|
||||
self.points_per_curve = points_per_curve
|
||||
self.alpha = alpha
|
||||
self.networks = []
|
||||
self.lines = []
|
||||
self.stops = {}
|
||||
self.interchange_nodes = set()
|
||||
self.transfers = {}
|
||||
self.gates = {}
|
||||
self.edges = []
|
||||
self.segments = {}
|
||||
self.shapes = []
|
||||
self.transit_graph = None
|
||||
self.matched_colors = {}
|
||||
self.stop_names = {}
|
||||
|
||||
def __get_average_stops_point(self, stop_ids):
|
||||
"""Returns an average position of the stops."""
|
||||
count = len(stop_ids)
|
||||
if count == 0:
|
||||
raise ValueError('Average stops point calculation failed: the list of stop id is empty.')
|
||||
average_point = [0, 0]
|
||||
for stop_id in stop_ids:
|
||||
point = self.__get_stop(stop_id)['point']
|
||||
average_point[0] += point['x']
|
||||
average_point[1] += point['y']
|
||||
return [average_point[0] / count, average_point[1] / count]
|
||||
|
||||
def __add_gate(self, osm_id, is_entrance, is_exit, point, weight, stop_id):
|
||||
"""Creates a new gate or adds information to the existing with the same weight."""
|
||||
if (osm_id, weight) in self.gates:
|
||||
gate_ref = self.gates[(osm_id, weight)]
|
||||
if stop_id not in gate_ref['stop_ids']:
|
||||
gate_ref['stop_ids'].append(stop_id)
|
||||
gate_ref['entrance'] |= is_entrance
|
||||
gate_ref['exit'] |= is_exit
|
||||
return
|
||||
gate = {'osm_id': osm_id,
|
||||
'point': point,
|
||||
'weight': weight,
|
||||
'stop_ids': [stop_id],
|
||||
'entrance': is_entrance,
|
||||
'exit': is_exit
|
||||
}
|
||||
self.gates[(osm_id, weight)] = gate
|
||||
|
||||
def __get_interchange_node(self, stop_id):
|
||||
"""Returns the existing interchange node or creates a new one."""
|
||||
for node_stops in self.interchange_nodes:
|
||||
if stop_id in node_stops:
|
||||
return node_stops
|
||||
return (stop_id,)
|
||||
|
||||
def __get_stop(self, stop_id):
|
||||
"""Returns the stop or the interchange node."""
|
||||
if stop_id in self.stops:
|
||||
return self.stops[stop_id]
|
||||
return self.transfers[stop_id]
|
||||
|
||||
def __check_line_title(self, line, route_name):
|
||||
"""Formats correct line name."""
|
||||
if line['title']:
|
||||
return
|
||||
name = route_name if route_name else line['number']
|
||||
if len(line['stop_ids']) > 1:
|
||||
first_stop = self.stop_names[line['stop_ids'][0]]
|
||||
last_stop = self.stop_names[line['stop_ids'][-1]]
|
||||
if first_stop and last_stop:
|
||||
line['title'] = u'{0}: {1} - {2}'.format(name, first_stop, last_stop)
|
||||
return
|
||||
line['title'] = name
|
||||
|
||||
def __read_stops(self):
|
||||
"""Reads stops, their exits and entrances."""
|
||||
for stop_item in self.input_data['stops']:
|
||||
stop = {}
|
||||
stop['id'] = stop_item['id']
|
||||
stop['osm_id'] = get_extended_osm_id(stop_item['osm_id'], stop_item['osm_type'])
|
||||
if 'zone_id' in stop_item:
|
||||
stop['zone_id'] = stop_item['zone_id']
|
||||
stop['point'] = get_mercator_point(stop_item['lat'], stop_item['lon'])
|
||||
stop['line_ids'] = []
|
||||
# TODO: Save stop names stop_item['name'] and stop_item['int_name'] for text anchors calculation.
|
||||
stop['title_anchors'] = []
|
||||
self.stops[stop['id']] = stop
|
||||
self.stop_names[stop['id']] = stop_item['name']
|
||||
|
||||
for entrance_item in stop_item['entrances']:
|
||||
ex_id = get_extended_osm_id(entrance_item['osm_id'], entrance_item['osm_type'])
|
||||
point = get_mercator_point(entrance_item['lat'], entrance_item['lon'])
|
||||
self.__add_gate(ex_id, True, False, point, entrance_item['distance'], stop['id'])
|
||||
|
||||
for exit_item in stop_item['exits']:
|
||||
ex_id = get_extended_osm_id(exit_item['osm_id'], exit_item['osm_type'])
|
||||
point = get_mercator_point(exit_item['lat'], exit_item['lon'])
|
||||
self.__add_gate(ex_id, False, True, point, exit_item['distance'], stop['id'])
|
||||
|
||||
def __read_transfers(self):
|
||||
"""Reads transfers between stops."""
|
||||
for transfer_item in self.input_data['transfers']:
|
||||
edge = {'stop1_id': transfer_item[0],
|
||||
'stop2_id': transfer_item[1],
|
||||
'weight': transfer_item[2],
|
||||
'transfer': True
|
||||
}
|
||||
self.edges.append(copy.deepcopy(edge))
|
||||
edge['stop1_id'], edge['stop2_id'] = edge['stop2_id'], edge['stop1_id']
|
||||
self.edges.append(edge)
|
||||
|
||||
def __read_networks(self):
|
||||
"""Reads networks and routes."""
|
||||
for network_item in self.input_data['networks']:
|
||||
network_id = network_item['agency_id']
|
||||
network = {'id': network_id,
|
||||
'title': network_item['network']}
|
||||
self.networks.append(network)
|
||||
|
||||
for route_item in network_item['routes']:
|
||||
line_index = 0
|
||||
# Create a line for each itinerary.
|
||||
for line_item in route_item['itineraries']:
|
||||
line_stops = line_item['stops']
|
||||
line_id = get_line_id(route_item['route_id'], line_index)
|
||||
line = {'id': line_id,
|
||||
'title': line_item.get('name', ''),
|
||||
'type': route_item['type'],
|
||||
'network_id': network_id,
|
||||
'number': route_item['ref'],
|
||||
'interval': line_item['interval'],
|
||||
'stop_ids': []
|
||||
}
|
||||
line['color'] = self.__match_color(route_item.get('colour', ''), route_item.get('casing', ''))
|
||||
|
||||
# TODO: Add processing of line_item['shape'] when this data will be available.
|
||||
# TODO: Add processing of line_item['trip_ids'] when this data will be available.
|
||||
|
||||
# Create an edge for each connection of stops.
|
||||
for i in range(len(line_stops)):
|
||||
stop1 = line_stops[i]
|
||||
line['stop_ids'].append(stop1[0])
|
||||
self.stops[stop1[0]]['line_ids'].append(line_id)
|
||||
if i + 1 < len(line_stops):
|
||||
stop2 = line_stops[i + 1]
|
||||
edge = {'stop1_id': stop1[0],
|
||||
'stop2_id': stop2[0],
|
||||
'weight': stop2[1] - stop1[1],
|
||||
'transfer': False,
|
||||
'line_id': line_id,
|
||||
'shape_ids': []
|
||||
}
|
||||
self.edges.append(edge)
|
||||
|
||||
self.__check_line_title(line, route_item.get('name', ''))
|
||||
self.lines.append(line)
|
||||
line_index += 1
|
||||
|
||||
def __match_color(self, color_str, casing_str):
|
||||
if color_str is None or len(color_str) == 0:
|
||||
return self.palette.get_default_color()
|
||||
if casing_str is None:
|
||||
casing_str = ''
|
||||
matched_colors_key = color_str + "/" + casing_str
|
||||
if matched_colors_key in self.matched_colors:
|
||||
return self.matched_colors[matched_colors_key]
|
||||
c = self.palette.get_nearest_color(color_str, casing_str, self.matched_colors.values())
|
||||
if c != self.palette.get_default_color():
|
||||
self.matched_colors[matched_colors_key] = c
|
||||
return c
|
||||
|
||||
def __generate_transfer_nodes(self):
|
||||
"""Merges stops into transfer nodes."""
|
||||
for edge in self.edges:
|
||||
if edge['transfer']:
|
||||
node1 = self.__get_interchange_node(edge['stop1_id'])
|
||||
node2 = self.__get_interchange_node(edge['stop2_id'])
|
||||
merged_node = tuple(sorted(set(node1 + node2)))
|
||||
self.interchange_nodes.discard(node1)
|
||||
self.interchange_nodes.discard(node2)
|
||||
self.interchange_nodes.add(merged_node)
|
||||
|
||||
for node_stop_ids in self.interchange_nodes:
|
||||
point = self.__get_average_stops_point(node_stop_ids)
|
||||
transfer = {'id': get_interchange_node_id(self.stops[node_stop_ids[0]]['id']),
|
||||
'stop_ids': list(node_stop_ids),
|
||||
'point': {'x': point[0], 'y': point[1]},
|
||||
'title_anchors': []
|
||||
}
|
||||
|
||||
for stop_id in node_stop_ids:
|
||||
self.stops[stop_id]['transfer_id'] = transfer['id']
|
||||
|
||||
self.transfers[transfer['id']] = transfer
|
||||
|
||||
def __collect_segments(self):
|
||||
"""Prepares collection of segments for shapes generation."""
|
||||
# Each line divided on segments by its stops and transfer nodes.
|
||||
# Merge equal segments from different lines into a single one and collect adjacent stops of that segment.
|
||||
# Average positions of these stops will be used as guide points for a curve generation.
|
||||
for line in self.lines:
|
||||
prev_seg = None
|
||||
prev_id1 = None
|
||||
for i in range(len(line['stop_ids']) - 1):
|
||||
node1 = self.stops[line['stop_ids'][i]]
|
||||
node2 = self.stops[line['stop_ids'][i + 1]]
|
||||
id1 = node1.get('transfer_id', node1['id'])
|
||||
id2 = node2.get('transfer_id', node2['id'])
|
||||
if id1 == id2:
|
||||
continue
|
||||
seg = tuple(sorted([id1, id2]))
|
||||
if seg not in self.segments:
|
||||
self.segments[seg] = {'guide_points': {id1: set(), id2: set()}}
|
||||
if prev_seg is not None:
|
||||
self.segments[seg]['guide_points'][id1].add(prev_id1)
|
||||
self.segments[prev_seg]['guide_points'][id1].add(id2)
|
||||
prev_seg = seg
|
||||
prev_id1 = id1
|
||||
|
||||
def __generate_shapes_for_segments(self):
|
||||
"""Generates a curve for each connection of two stops / transfer nodes."""
|
||||
for (id1, id2), info in self.segments.items():
|
||||
point1 = [self.__get_stop(id1)['point']['x'], self.__get_stop(id1)['point']['y']]
|
||||
point2 = [self.__get_stop(id2)['point']['x'], self.__get_stop(id2)['point']['y']]
|
||||
|
||||
if info['guide_points'][id1]:
|
||||
guide1 = self.__get_average_stops_point(info['guide_points'][id1])
|
||||
else:
|
||||
guide1 = [2 * point1[0] - point2[0], 2 * point1[1] - point2[1]]
|
||||
|
||||
if info['guide_points'][id2]:
|
||||
guide2 = self.__get_average_stops_point(info['guide_points'][id2])
|
||||
else:
|
||||
guide2 = [2 * point2[0] - point1[0], 2 * point2[1] - point1[1]]
|
||||
|
||||
curve_points = bezier_curves.segment_to_Catmull_Rom_curve(guide1, point1, point2, guide2,
|
||||
self.points_per_curve, self.alpha)
|
||||
info['curve'] = np.array(curve_points)
|
||||
|
||||
polyline = []
|
||||
for point in curve_points:
|
||||
polyline.append({'x': point[0], 'y': point[1]})
|
||||
|
||||
shape = {'id': {'stop1_id': id1, 'stop2_id': id2},
|
||||
'polyline': polyline}
|
||||
self.shapes.append(shape)
|
||||
|
||||
def __assign_shapes_to_edges(self):
|
||||
"""Assigns a shape to each non-transfer edge."""
|
||||
for edge in self.edges:
|
||||
if not edge['transfer']:
|
||||
stop1 = self.stops[edge['stop1_id']]
|
||||
stop2 = self.stops[edge['stop2_id']]
|
||||
id1 = stop1.get('transfer_id', stop1['id'])
|
||||
id2 = stop2.get('transfer_id', stop2['id'])
|
||||
seg = tuple(sorted([id1, id2]))
|
||||
if seg in self.segments:
|
||||
edge['shape_ids'].append({'stop1_id': seg[0], 'stop2_id': seg[1]})
|
||||
|
||||
def __create_scheme_shapes(self):
|
||||
self.__collect_segments()
|
||||
self.__generate_shapes_for_segments()
|
||||
self.__assign_shapes_to_edges()
|
||||
|
||||
def build(self):
|
||||
if self.transit_graph is not None:
|
||||
return self.transit_graph
|
||||
|
||||
self.__read_stops()
|
||||
self.__read_transfers()
|
||||
self.__read_networks()
|
||||
self.__generate_transfer_nodes()
|
||||
self.__create_scheme_shapes()
|
||||
|
||||
self.transit_graph = {'networks': self.networks,
|
||||
'lines': self.lines,
|
||||
'gates': list(self.gates.values()),
|
||||
'stops': list(self.stops.values()),
|
||||
'transfers': list(self.transfers.values()),
|
||||
'shapes': self.shapes,
|
||||
'edges': self.edges}
|
||||
return self.transit_graph
|
||||
|
||||
def show_preview(self):
|
||||
import matplotlib.pyplot as plt
|
||||
for (s1, s2), info in self.segments.items():
|
||||
plt.plot(info['curve'][:, 0], info['curve'][:, 1], 'g')
|
||||
for stop in self.stops.values():
|
||||
if 'transfer_id' in stop:
|
||||
point = self.transfers[stop['transfer_id']]['point']
|
||||
size = 60
|
||||
color = 'r'
|
||||
else:
|
||||
point = stop['point']
|
||||
if len(stop['line_ids']) > 2:
|
||||
size = 40
|
||||
color = 'b'
|
||||
else:
|
||||
size = 20
|
||||
color = 'g'
|
||||
plt.scatter([point['x']], [point['y']], size, color)
|
||||
plt.show()
|
||||
|
||||
def show_color_maching_table(self, title, colors_ref_table):
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, aspect='equal')
|
||||
plt.title(title)
|
||||
sz = 1.0 / (2.0 * len(self.matched_colors))
|
||||
delta_y = sz * 0.5
|
||||
for c in self.matched_colors:
|
||||
tokens = c.split('/')
|
||||
if len(tokens[1]) == 0:
|
||||
tokens[1] = tokens[0]
|
||||
ax.add_patch(patches.Rectangle((sz, delta_y), sz, sz, facecolor="#" + tokens[0], edgecolor="#" + tokens[1]))
|
||||
rect_title = tokens[0]
|
||||
if tokens[0] != tokens[1]:
|
||||
rect_title += "/" + tokens[1]
|
||||
ax.text(2.5 * sz, delta_y, rect_title + " -> ")
|
||||
ref_color = colors_ref_table[self.matched_colors[c]]
|
||||
ax.add_patch(patches.Rectangle((0.3 + sz, delta_y), sz, sz, facecolor="#" + ref_color))
|
||||
ax.text(0.3 + 2.5 * sz, delta_y, ref_color + " (" + self.matched_colors[c] + ")")
|
||||
delta_y += sz * 2.0
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('input_file', help='input file name of transit data')
|
||||
parser.add_argument('output_file', nargs='?', help='output file name of generated graph')
|
||||
default_colors_path = os.path.dirname(os.path.abspath(__file__)) + '/../../../data/transit_colors.txt'
|
||||
parser.add_argument('-c', '--colors', type=str, default=default_colors_path,
|
||||
help='transit colors file COLORS_FILE_PATH', metavar='COLORS_FILE_PATH')
|
||||
parser.add_argument('-p', '--preview', action="store_true", default=False,
|
||||
help="show preview of the transit scheme")
|
||||
parser.add_argument('-m', '--matched_colors', action="store_true", default=False,
|
||||
help="show the matched colors table")
|
||||
|
||||
|
||||
parser.add_argument('-a', '--alpha', type=float, default=0.5, help='the curves generator parameter value ALPHA',
|
||||
metavar='ALPHA')
|
||||
parser.add_argument('-n', '--num', type=int, default=100, help='the number NUM of points in a generated curve',
|
||||
metavar='NUM')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.input_file, 'r') as input_file:
|
||||
data = json.load(input_file)
|
||||
|
||||
with open(args.colors, 'r') as colors_file:
|
||||
colors = json.load(colors_file)
|
||||
|
||||
transit = TransitGraphBuilder(data, colors, args.num, args.alpha)
|
||||
result = transit.build()
|
||||
|
||||
output_file = args.output_file
|
||||
head, tail = os.path.split(os.path.abspath(args.input_file))
|
||||
name, extension = os.path.splitext(tail)
|
||||
if output_file is None:
|
||||
output_file = os.path.join(head, name + '.transit' + extension)
|
||||
with open(output_file, 'w') as json_file:
|
||||
result_data = json.dumps(result, ensure_ascii=False, indent='\t', sort_keys=True, separators=(',', ':'))
|
||||
json_file.write(result_data)
|
||||
print('Transit graph generated:', output_file)
|
||||
|
||||
if args.preview:
|
||||
transit.show_preview()
|
||||
|
||||
if args.matched_colors:
|
||||
colors_ref_table = {}
|
||||
for color_name, color_info in colors['colors'].items():
|
||||
colors_ref_table[color_name] = color_info['clear']
|
||||
transit.show_color_maching_table(name, colors_ref_table)
|
||||
Loading…
Add table
Add a link
Reference in a new issue