Source code for libadalina_core.graph_extraction.readers.reader
from abc import abstractmethod
from enum import Enum
import geopandas as gpd
from libadalina_core.exceptions.input_file_exception import InputFileException
[docs]
class RoadTypes(Enum):
    """Road categories a reader can be asked to keep when filtering."""

    #: Keep all roads.
    ALL = 'all'

    #: Keep only roads accessible by car.
    CAR_ONLY = 'only_car'

    #: Keep only main roads (motorways, trunks, primary).
    MAIN_ROADS = 'main_roads'
[docs]
class MandatoryColumns(Enum):
    """Columns that must be present in the input DataFrame.

    Each member's value is the standardized column name that
    ``MapReader.map_and_reduce`` writes to and keeps in its output.
    """

    #: Standardized column ``'id'`` (presumably the road identifier -- confirm).
    id = 'id'

    #: Standardized column ``'name'`` (presumably the road name -- confirm).
    road_name = 'name'

    #: Standardized column ``'oneway'`` (presumably the travel direction -- confirm).
    oneway = 'oneway'
[docs]
class OneWay(Enum):
    """Allowed travel directions of a road relative to its geometry."""

    #: Follows the direction of the geometry.
    Forward = 'forward'

    #: Goes against the direction of the geometry.
    Backward = 'backward'

    #: Both directions are allowed.
    Both = 'both'
class MapReader:
    """
    Base class for map readers.

    It stores the requested road type filter, declares the ``_filter_roads``
    hook that subclasses are expected to override, and provides
    ``map_and_reduce`` to normalize column names of an input GeoDataFrame.
    """

    # NOTE(review): ``_filter_roads`` is decorated with ``@abstractmethod`` but
    # this class does not derive from ``abc.ABC``, so the decorator does not
    # block instantiation; only the explicit ``NotImplementedError`` guards the
    # method. Left unchanged so that existing direct instantiations keep
    # working -- confirm whether an ABC base is intended.

    def __init__(self, road_types: RoadTypes = RoadTypes.ALL):
        """
        Initialize the MapReader with the specified road type filter.

        Parameters
        ----------
        road_types : RoadTypes, optional
            The type of roads to keep. Default is RoadTypes.ALL, which keeps all roads.
        """
        self._road_types = road_types

    @abstractmethod
    def _filter_roads(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
        """
        Filter roads based on the specified road type.

        This method should be implemented by subclasses to apply specific filtering logic.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def map_and_reduce(self, gdf: gpd.GeoDataFrame, column_map: dict[MandatoryColumns, str]) -> gpd.GeoDataFrame:
        """
        Remap column names and project only mandatory columns.

        The input GeoDataFrame is left untouched: remapping is performed on a copy.

        Parameters
        ----------
        gdf : geopandas.GeoDataFrame
            The input GeoDataFrame to be processed.
        column_map : dict[MandatoryColumns, str]
            The column names mapping: keys are the expected column names (as defined in the MandatoryColumns enum),
            values are the actual column names in the input GeoDataFrame.

        Returns
        -------
        geopandas.GeoDataFrame
            A GeoDataFrame containing only the ``geometry`` column and the mandatory columns
            with standardized names.

        Raises
        ------
        InputFileException
            If a source column named in ``column_map`` does not exist, or if a mandatory
            column is still missing after the remapping.
        KeyError
            If the dataframe has no ``geometry`` column.
        """
        # Copy first so that the new/overwritten columns never leak into the caller's frame.
        remapped = gdf.copy()

        # Mappings are applied one after the other, so a later entry may read a
        # column produced by an earlier one; hence the per-step existence check.
        for key, value in column_map.items():
            if value not in remapped.columns:
                raise InputFileException(f"missing column {value} in dataframe")
            remapped[key.value] = remapped[value]

        # Validate BEFORE projecting: selecting a missing column would raise a bare
        # KeyError and the dedicated exception below would never be reached.
        for c in MandatoryColumns:
            if c.value not in remapped.columns:
                raise InputFileException(f"missing column {c.value} in dataframe")

        return remapped[['geometry'] + [c.value for c in MandatoryColumns]]