#!/usr/bin/env python
"""Segment lidar.

Usage:
  segment-lidar.py <cloud_path> <ckpt_path> [--resolution=<px>] [--cpu]
  segment-lidar.py (-h | --help)
  segment-lidar.py --version

Arguments:
  cloud_path  Path to the .las/laz pointcloud.
  ckpt_path   Path to SAM model weights.

Options:
  --resolution=<px>  The resolution of the image in units per pixel [default: 0.25].
  --cpu              Force CPU execution (disable GPU).

"""
import os
from docopt import docopt
import segment_lidar.samlidar


def _model_type_from_ckpt(ckpt_path: str) -> str:
    """Extract the 5-character ``vit_?`` model type embedded in a checkpoint path.

    The last occurrence of ``"vit_"`` is used so that the checkpoint's file
    name takes precedence over a parent directory that happens to contain the
    same marker.

    Raises:
        ValueError: if the path contains no ``"vit_"`` marker, or the marker
            is not followed by a character identifying the variant.
    """
    marker = "vit_"
    start = ckpt_path.rfind(marker)
    # str.rfind returns -1 when the marker is absent; slicing with -1 would
    # silently yield an empty/garbage model type instead of failing.
    model_type = ckpt_path[start:start + len(marker) + 1] if start >= 0 else ""
    if len(model_type) != len(marker) + 1:
        raise ValueError(
            f"Cannot infer the SAM model type from checkpoint path {ckpt_path!r}: "
            "expected the path to contain 'vit_<x>' (for example 'vit_h')."
        )
    return model_type


def main(cloud_path: str, ckpt_path: str, resolution: float = 0.25, cpu: bool = False) -> None:
    """Segment a point cloud with segment-lidar and write the results to disk.

    Outputs are written to a ``segment-lidar`` directory created next to the
    input cloud, using the cloud's base name as prefix:
    ``<name>_raster.tif``, ``<name>_labeled.tif`` and ``<name>_segmented.las``.

    Args:
        cloud_path: Path to the input point cloud.
        ckpt_path: Path to the SAM model weights. The model type is inferred
            from the ``vit_<x>`` marker contained in this path.
        resolution: Image resolution in units per pixel. Strings (as delivered
            by the command line) are converted with ``float``.
        cpu: When true, ``CUDA_VISIBLE_DEVICES`` is set to an empty string
            before the model is created.

    Raises:
        ValueError: if ``resolution`` is not convertible to ``float`` or the
            model type cannot be inferred from ``ckpt_path``.
    """
    resolution = float(resolution)
    # Fail fast on an unusable checkpoint name, before touching the
    # environment or loading any data.
    model_type = _model_type_from_ckpt(ckpt_path)
    if cpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # Derive all output paths from the input cloud's location and base name.
    cloud_stem, _ = os.path.splitext(cloud_path)
    out_dir = os.path.join(os.path.dirname(cloud_stem), "segment-lidar")
    out_stem = os.path.join(out_dir, os.path.basename(cloud_stem))
    image_path = out_stem + "_raster.tif"
    labels_path = out_stem + "_labeled.tif"
    save_path = out_stem + "_segmented.las"

    # run segment-lidar
    model = segment_lidar.samlidar.SamLidar(ckpt_path, model_type=model_type, resolution=resolution)
    points = model.read(cloud_path)
    cloud, non_ground, ground = model.csf(points)
    # The output directory must exist before segment() writes the rasters.
    os.makedirs(out_dir, exist_ok=True)
    labels, *_ = model.segment(points=cloud, image_path=image_path, labels_path=labels_path)
    model.write(points=points, non_ground=non_ground, ground=ground, segment_ids=labels, save_path=save_path)


if __name__ == "__main__":
    # Parse the command line according to the usage text in the module docstring.
    parsed = docopt(__doc__, version="Segment lidar 0.1.5")
    # Entries that only serve docopt itself (help/version switches) are not
    # forwarded to main(); neither are options the user left unset (None),
    # so that main()'s own defaults apply.
    meta_markers = ("--help", "--version", "-h", "-v")
    kwargs = {}
    for option, value in parsed.items():
        if value is None or any(marker in option for marker in meta_markers):
            continue
        # Turn docopt keys into keyword names: "<cloud_path>" -> "cloud_path",
        # "--cpu" -> "cpu".
        kwargs[option.strip("<>--")] = value
    main(**kwargs)
