import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
routes_url = "https://raw.githubusercontent.com/jpatokal/openflights/master/data/routes.dat"
# routes.dat has no header row, so every field is named at read time; the
# codeshare / stops / equipment fields are not used further and are dropped,
# leaving airline plus the source/target codes and ids.
routes_df = pd.read_csv(
    routes_url,
    header=None,
    names=[
        "airline", "id", "source", "source_id", "target", "target_id",
        "codeshare", "stops", "equipments",
    ],
).drop(columns=["codeshare", "stops", "equipments"])
routes_df.head()
airport_url = "https://raw.githubusercontent.com/jpatokal/openflights/master/data/airports.dat"
# airports.dat has no header row. OpenFlights writes missing values as the
# literal string "\N"; parse that as NaN so absent IATA/ICAO codes are real
# nulls instead of a sentinel string that looks like an airport code.
airport_df = pd.read_csv(airport_url, header=None, na_values=r"\N")
airport_df.columns = ["airport_id", "name", "city", "country", "iata", "icao", "lat", "long", "altitude", "tmz",
                      "dst", "tzt", "type", "source"]
# Keep only identifying columns; coordinates/timezone data are not used here.
airport_df = airport_df[["airport_id", "iata", "icao", "country", "city", "name"]]
airport_df.head()
us_airport_df = airport_df[airport_df.country == "United States"]
us_airport_df.head()
# A US airport can be referenced in routes by either its IATA or ICAO code.
# A missing code shows up as NaN or as the literal "\N" (depending on how the
# CSV was parsed); neither identifies an airport, so both are removed.
kept_iata = set(us_airport_df.iata.dropna()) - {r"\N"}
kept_icao = set(us_airport_df.icao.dropna()) - {r"\N"}
kept_codes = kept_iata | kept_icao
# Domestic routes only: both endpoints must be US airports.
us_routes = routes_df[
    routes_df.source.isin(kept_codes) & routes_df.target.isin(kept_codes)
]
len(us_routes)
# keep only the topK most popular airports (and routes between them)
from collections import Counter

topK = 100
# Popularity = number of US routes an airport appears in, as source or target.
cnt = Counter(us_routes.source.tolist())
cnt.update(us_routes.target.tolist())
# most_common orders equal counts by first appearance, so ties at the topK
# cut-off are resolved the same way as a stable descending sort.
kept_codes = [code for code, _ in cnt.most_common(topK)]
kept_routes = us_routes[
    us_routes.source.isin(kept_codes) & us_routes.target.isin(kept_codes)
]
# One undirected 0/1 adjacency matrix per airline, all indexed by the same
# list of airports so the matrices are directly comparable.
networks = {}
# Sorted rather than list(set(...)): set iteration order for strings depends on
# per-process hash randomization, so the row/column layout (and any saved
# matrix) would otherwise change from run to run. key=str keeps the sort
# total even if a non-string value (e.g. NaN) ends up in a column.
airlines = sorted(set(kept_routes.airline), key=str)
airports = sorted(set(kept_routes.source) | set(kept_routes.target), key=str)
airport2ind = {k: i for i, k in enumerate(airports)}
n_airports = len(airports)
for airl in airlines:
    curr = kept_routes[kept_routes.airline == airl]
    # Every endpoint is in airport2ind because airports was built from kept_routes.
    src = curr.source.map(airport2ind).to_numpy(dtype=np.intp)
    dst = curr.target.map(airport2ind).to_numpy(dtype=np.intp)
    net = np.zeros((n_airports, n_airports), dtype=np.int32)
    # Mark both directions so the matrix is symmetric (undirected graph).
    net[src, dst] = 1
    net[dst, src] = 1
    networks[airl] = net
# One adjacency-matrix image per airline, 8 per row. The grid is sized from the
# number of airlines, so it neither runs out of axes nor leaves empty frames.
ncols = 8
nrows = max(1, (len(airlines) + ncols - 1) // ncols)  # ceiling division
# Height scales with the row count; 7 rows gives the original 20x20 canvas.
fig, axes = plt.subplots(nrows, ncols, figsize=(20, 20 * nrows / 7), squeeze=False)
axes = axes.ravel()
for i, air in enumerate(airlines):
    axes[i].set_title(air)
    axes[i].imshow(networks[air])
    axes[i].set_xticks([])
    axes[i].set_yticks([])
# Hide whatever cells of the last row are not used by an airline.
for ax in axes[len(airlines):]:
    ax.set_visible(False)
plt.show()
# basefile = "data/airline_network/{0}.csv"
# for airline, net in networks.items():
# file = basefile.format(airline)
# np.savetxt(file, net, fmt='%i', delimiter=",")