#!/bin/bash

set -euo pipefail

# Train the geowatch fusion SequenceAwareModel on the ONERA 2018 kwcoco bundle.
# With --trainer.max_steps=5 this is a short smoke-test configuration.
#
# Environment:
#   EXPERIMENT_NAME  run name forwarded to --model.name
#                    (default: onera_sequence_aware)

KWCOCO_BUNDLE_DPATH=$HOME/Projects/SMART/smart_watch_dvc/extern/onera_2018
TRAIN_FPATH=$KWCOCO_BUNDLE_DPATH/onera_train.kwcoco.json
# NOTE: validation deliberately reuses the test split; no separate
# validation file is referenced for this bundle.
VALI_FPATH=$KWCOCO_BUNDLE_DPATH/onera_test.kwcoco.json
TEST_FPATH=$KWCOCO_BUNDLE_DPATH/onera_test.kwcoco.json
CHANNELS="B02|B03|B04"
# Only consumed by the commented-out --init option at the bottom of the file.
# shellcheck disable=SC2034
INITIAL_STATE="noop"
# EXPERIMENT_NAME is not assigned anywhere in this script; default it so that
# --model.name never silently expands to an empty string.
EXPERIMENT_NAME=${EXPERIMENT_NAME:-onera_sequence_aware}

# NOTE: four GPUs are made visible, but --trainer.devices=1 means only one of
# them is used; raise --trainer.devices to train on more.
# The --model class path uses the same `geowatch` package as the entry point.
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m geowatch.tasks.fusion.fit_lightning fit \
    --data.batch_size=8 \
    --data.channels="$CHANNELS" \
    --data.chip_dims=64 \
    --data.chip_overlap=0.25 \
    --data.neg_to_pos_ratio=0.25 \
    --data.num_workers=16 \
    --data.resample_invalid_frames=0 \
    --data.use_cloudmask=0 \
    --data.set_cover_algo=approx \
    --data.temporal_dropout=0.5 \
    --data.test_dataset="$TEST_FPATH" \
    --data.time_sampling="soft2+distribute" \
    --data.time_span=1m \
    --data.time_steps=3 \
    --data.train_dataset="$TRAIN_FPATH" \
    --data.use_centered_positives=false \
    --data.vali_dataset="$VALI_FPATH" \
    --data.sqlview=false \
    --model=geowatch.tasks.fusion.methods.SequenceAwareModel \
    --model.stream_channels=1 \
    --model.class_loss='dicefocal' \
    --model.decoder=mlp \
    --model.dropout=0.1 \
    --model.global_change_weight=1.00 \
    --model.global_class_weight=1.00 \
    --model.global_saliency_weight=1.00 \
    --model.learning_rate=1e-4 \
    --model.name="$EXPERIMENT_NAME" \
    --model.optimizer=AdamW \
    --model.render_outputs=true \
    --model.saliency_loss='focal' \
    --model.tokenizer="linconv" \
    --model.weight_decay=0 \
    --trainer.accelerator="gpu" \
    --trainer.devices=1 \
    --trainer.precision=16 \
    --trainer.default_root_dir="lightning_logs" \
    --trainer.max_steps=5

# Optional extra arguments. Bash ends a backslash-continued command at a
# comment line, so these cannot be toggled in place: to enable one, move it
# above --trainer.max_steps and keep its trailing backslash.
    # --trainer.accumulate_grad_batches=16 \
    # --trainer.num_sanity_val_steps=2 \
    # --trainer.fast_dev_run=5 \
    # --trainer.strategy=deepspeed_stage_2_offload \
    # --trainer.track_grad_norm=2 \
    # --trainer.amp_backend=apex \
    # --init="$INITIAL_STATE" \
    # --patience=160 \
    # --draw_interval=5min \
    # --num_draw=1 \
