diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6fc0ebc19afa539940ca9f6b7f3cedb6e8973eea..60a8f07ed75e20cc3a7852dbea153969928ac6db 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -142,6 +142,7 @@ stages: git add dev/public_gpg_key + .gpgsign_template: &gpgsign_template <<: - *common_template @@ -231,6 +232,9 @@ stages: # Define the actual jobs + + + # --------------- # Python 3.8 Jobs @@ -238,13 +242,13 @@ build/cp38-cp38-linux: <<: - *build_template image: - python:3.8 + gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.8 test_full/cp38-cp38-linux: <<: - *test_full_template image: - python:3.8 + gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.8 # for universal builds we only need to gpg sign once @@ -252,13 +256,13 @@ gpgsign/cp38-cp38-linux: <<: - *gpgsign_template image: - python:3.8 + gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.8 deploy/cp38-cp38-linux: <<: - *deploy_template image: - python:3.8 + gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.8 # --------------- @@ -268,13 +272,13 @@ build/cp37-cp37m-linux: <<: - *build_template image: - python:3.7 + gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.7 test_full/cp37-cp37m-linux: <<: - *test_full_template image: - python:3.7 + gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.7 #gpgsign/cp37-cp37m-linux: @@ -297,13 +301,13 @@ build/cp36-cp36m-linux: <<: - *build_template image: - python:3.6 + gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.6 test_full/cp36-cp36m-linux: <<: - *test_full_template image: - python:3.6 + gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.6 #gpgsign/cp36-cp36m-linux: # <<: @@ -321,54 +325,67 @@ test_full/cp36-cp36m-linux: # --------------- # Python 3.5 Jobs -build/cp35-cp35m-linux: - <<: - - *build_template - image: - python:3.5 - -test_full/cp35-cp35m-linux: - <<: - - *test_full_template - image: - python:3.5 - -#gpgsign/cp35-cp35m-linux: +#build/cp35-cp35m-linux: # <<: -# - *gpgsign_template +# - *build_template # image: # python:3.5 -#deploy/cp35-cp35m-linux: +#test_full/cp35-cp35m-linux: # <<: -# - *deploy_template +# - *test_full_template # image: # python:3.5 +.__local_docker_doc__: + - | + # Docker rate limiting messed us up + # On my local machine + + docker login gitlab.kitware.com:4567 + + docker pull python:3.8 + docker pull python:3.7 + docker pull python:3.6 + + docker tag python:3.6 gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.6 + docker tag python:3.7 gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.7 + docker tag python:3.8 gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.8 -# --------------- -# Python 2.7 Jobs + docker push gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.6 + docker push gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.7 + docker push gitlab.kitware.com:4567/computer-vision/ndsampler/python:3.8 -#build/cp27-cp27mu-linux: -# <<: -# - *build_template -# image: -# python:2.7 + docker tag python:3.6 gitlab.kitware.com:4567/computer-vision/python:3.6 + docker tag python:3.7 gitlab.kitware.com:4567/computer-vision/python:3.7 + docker tag python:3.8 gitlab.kitware.com:4567/computer-vision/python:3.8 -#test_full/cp27-cp27mu-linux: -# <<: -# - *test_full_template -# image: -# python:2.7 + docker push gitlab.kitware.com:4567/computer-vision/python:3.6 + docker push gitlab.kitware.com:4567/computer-vision/python:3.7 + docker push gitlab.kitware.com:4567/computer-vision/python:3.8 -#gpgsign/cp27-cp27mu-linux: -# <<: -# - *gpgsign_template -# image: -# python:2.7 -#deploy/cp27-cp27mu-linux: -# <<: -# - *deploy_template -# image: -# python:2.7 +.__local_docker_test__: + - | + # Docker rate limiting messed us up + # On my local machine + docker run -it python:3.8 bash + + apt update -y && apt install git -y + git clone https://gitlab.kitware.com/computer-vision/ndsampler.git + cd ndsampler + + python setup.py bdist_wheel --universal + python -V # Print out python version for debugging + export PYVER=$(python -c "import sys; print('{}{}'.format(*sys.version_info[0:2]))") + pip install virtualenv + virtualenv venv$PYVER + source venv$PYVER/bin/activate + pip install pip -U + pip install pip setuptools -U + python -V # Print out python version for debugging + pip install -r requirements.txt + pip install . + + apt update && apt install libgl1-mesa-glx -y && rm -rf /var/lib/apt/lists/* + ./run_tests.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f754671c08fee8031507843c076f3bba80c3acee..64dfabcd7652a656a6ffccd9ff524631423ea32c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,22 @@ This changelog follows the specifications detailed in: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), although we have not yet reached a `1.0.0` release. +## Version 0.5.12 - Unreleased + +### Changed +* Removed Python 3.5 support +* No longer using protocol 2 in the Cacher +* Better `LazyGDalFrameFile.demo` classmethod. + +### Fixed +* Bug in accessing the LRU + +## Version 0.5.11 - Released 2020-08-26 + +### Fixed +* Minor compatibility fixes for `ndsampler.CategoryTree` and `kwcoco.CategoryTree` + + ## Version 0.5.10 - Released 2020-06-25 ### Added @@ -35,7 +51,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm * `batch_validate_cog` to `util_gdal` -## Version 0.5.6 - Unreleased +## Version 0.5.6 - Released 2020-08-26 ### Added * `util_lru` containing implementations of a dictionary based LRU cache. @@ -207,5 +223,3 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm * Initial code for CategoryTree * Initial code for CocoDataset * Initial code for dummy detection toydata - -## Version 0.5.11 - Unreleased diff --git a/dev/bench_video_readers.py b/dev/bench_video_readers.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e0d42c7366299a59f4d841f123ba2d443b18c8 --- /dev/null +++ b/dev/bench_video_readers.py @@ -0,0 +1,211 @@ +import pandas as pd +import ubelt as ub +import cv2 +import kwarray +import timerit + + +class CV2VideoReader(ub.NiceRepr): + def __init__(self, fpath): + self.fpath = fpath + self._cap = cv2.VideoCapture(fpath) + self._len = None + + def tell(self): + index = self._cap.get(cv2.CAP_PROP_POS_FRAMES) + return index + + def seek(self, index): + self._cap.set(cv2.CAP_PROP_POS_FRAMES, index) + + def __del__(self): + self._cap.release() + + def meta(self): + keys = [n for n in dir(cv2) if n.startswith('CAP_PROP_')] + meta = {k: self._cap.get(getattr(cv2, k)) for k in keys} + return meta + + def __len__(self): + if self._len is None: + self._len = int(self._cap.get(cv2.CAP_PROP_FRAME_COUNT)) + return self._len + + def __iter__(self): + while True: + ret, frame = self._cap.read() + if not ret: + break + yield frame + + def __getitem__(self, index): + self.seek(index) + ret, frame = self._cap.read() + if not ret: + raise IndexError(index) + return frame + + +def benchmark_video_readers(): + # video_fpath = ub.grabdata('https://download.blender.org/peach/bigbuckbunny_movies/big_buck_bunny_720p_h264.mov') + try: + import vi3o + except Exception: + vi3o = None + + video_fpath = ub.grabdata('https://download.blender.org/peach/bigbuckbunny_movies/BigBuckBunny_320x180.mp4') + video_fpath = ub.grabdata('https://file-examples-com.github.io/uploads/2018/04/file_example_MOV_1280_1_4MB.mov') + + ti = timerit.Timerit(10, bestof=3, verbose=3, unit='ms') + + video_length = len(CV2VideoReader(video_fpath)) + num_frames = min(5, video_length) + rng = kwarray.ensure_rng(0) + random_indices = rng.randint(0, video_length, size=num_frames).tolist() + + if True: + with timerit.Timer(label='open cv2') as cv2_open_timer: + cv2_video = CV2VideoReader(video_fpath) + + for timer in ti.reset('cv2 sequential access'): + cv2_video.seek(0) + with timer: + for frame, _ in zip(cv2_video, range(num_frames)): + pass + + for timer in ti.reset('cv2 random access'): + with timer: + for index in random_indices: + cv2_video[index] + + if vi3o is not None: + with timerit.Timer(label='open vi3o') as vi3o_open_timer: + vi3o_video = vi3o.Video(video_fpath) + + for timer in ti.reset('vi3o sequential access'): + with timer: + for frame, _ in zip(vi3o_video, range(num_frames)): + pass + + for timer in ti.reset('vi3o random access'): + with timer: + for index in random_indices: + vi3o_video[index] + + if True: + import decord + with timerit.Timer(label='open decord') as decord_open_timer: + decord_video = decord.VideoReader(video_fpath) + + for timer in ti.reset('decord sequential access'): + with timer: + for frame, _ in zip(decord_video, range(num_frames)): + pass + + for timer in ti.reset('decord random access'): + with timer: + for index in random_indices: + decord_video[index] + + for timer in ti.reset('decord random batch access'): + with timer: + decord_video.get_batch(random_indices) + + if True: + # One Random Access Case + + def _work_to_clear_io_caches(): + import kwimage + # Let some caches be cleared + for i in range(10): + for key in kwimage.grab_test_image.keys(): + kwimage.grab_test_image(key) + + rng = kwarray.ensure_rng(0) + for timer in ti.reset('cv2 open + one random access'): + _work_to_clear_io_caches() + with timer: + _cv2_video = CV2VideoReader(video_fpath) + index = rng.randint(0, video_length, size=1)[0] + _cv2_video[index] + + if vi3o is not None: + rng = kwarray.ensure_rng(0) + for timer in ti.reset('vi3o open + one random access'): + _work_to_clear_io_caches() + with timer: + _vi3o_video = vi3o.Video(video_fpath) + index = rng.randint(0, video_length, size=1)[0] + _vi3o_video[index] + + rng = kwarray.ensure_rng(0) + for timer in ti.reset('decord open + one random access'): + _work_to_clear_io_caches() + with timer: + _decord_video = decord.VideoReader(video_fpath) + index = rng.randint(0, video_length, size=1)[0] + _decord_video[index] + + for timer in ti.reset('cv2 open + first access'): + _work_to_clear_io_caches() + with timer: + _cv2_video = CV2VideoReader(video_fpath) + _cv2_video[0] + + if vi3o is not None: + for timer in ti.reset('vi3o open + first access'): + _work_to_clear_io_caches() + with timer: + _vi3o_video = vi3o.Video(video_fpath) + _vi3o_video[0] + + for timer in ti.reset('decord open + first access'): + _work_to_clear_io_caches() + with timer: + _decord_video = decord.VideoReader(video_fpath) + _decord_video[0] + + measures = ub.map_vals(ub.sorted_vals, ti.measures) + print('ti.measures = {}'.format(ub.repr2(measures, nl=2, align=':', precision=4))) + print('cv2_open_timer.elapsed = {!r}'.format(cv2_open_timer.elapsed)) + print('decord_open_timer.elapsed = {!r}'.format(decord_open_timer.elapsed)) + if vi3o: + print('vi3o_open_timer.elapsed = {!r}'.format(vi3o_open_timer.elapsed)) + + import kwplot + import seaborn as sns + sns.set() + kwplot.autompl() + + df = pd.DataFrame(ti.measures) + df['key'] = df.index + df['expt'] = df['key'].apply(lambda k: ' '.join(k.split(' ')[1:])) + df['module'] = df['key'].apply(lambda k: k.split(' ')[0]) + + # relmod = 'decord' + relmod = 'cv2' + for k, group in df.groupby('expt'): + measure = 'mean' + relval = group[group['module'] == relmod][measure].values.ravel() + if len(relval) > 0: + assert len(relval) == 1 + df.loc[group.index, measure + '_rel'] = group[measure] / relval + df.loc[group.index, measure + '_slower_than_' + relmod] = group[measure] / relval + df.loc[group.index, measure + '_faster_than_' + relmod] = relval / group[measure] + + fig = kwplot.figure(fnum=1, doclf=True) + ax = fig.gca() + y_key = "mean_faster_than_" + relmod + + sub_df = df.loc[~df[y_key].isnull()] + sns.barplot( + x="expt", y=y_key, data=sub_df, hue='module', ax=ax) + ax.set_title('cpu video reading benchmarks') + + +if __name__ == '__main__': + """ + CommandLine: + python ~/code/ndsampler/dev/bench_video_readers.py + """ + benchmark_video_readers() diff --git a/dev/immutable_dict.py b/dev/immutable_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..f3500a386d54de8779747ba81b96f4743e1910a6 --- /dev/null +++ b/dev/immutable_dict.py @@ -0,0 +1,97 @@ +import math +import typing +import torch + + +class DictArray: + def __init__(self, + dicts : typing.List[dict], + types : typing.Dict[str, typing.Type] = {}, + *, + batch_size : int = 1024, + string_encoding : typing.Literal['ascii', 'utf_16_le', 'utf_32_le'] = 'utf_16_le', + ints_dtype=torch.int64): + pass + + dicts = list(dicts) + numel = len(dicts) + assert numel > 0 + self.tensors = {k : t(numel) for k, t in types.items() if t != StringArray and t != IntsArray} + string_lists = {k : [None] * numel for k, t in types.items() if t == StringArray} + ints_lists = {k : [None] * numel for k, t in types.items() if t == IntsArray} + temp_lists = {k : [None] * batch_size for k in self.tensors} + for b in range(math.ceil(numel / batch_size)): + for i, t in enumerate(dicts[b * batch_size : (b + 1) * batch_size]): + for k in temp_lists: + temp_lists[k][i] = t[k] + for k in string_lists: + string_lists[k][b * batch_size + i] = t[k] + for k in ints_lists: + ints_lists[k][b * batch_size + i] = t[k] + for k, v in temp_lists.items(): + res = self.tensors[k][b * batch_size : (b + 1) * batch_size] + res.copy_(torch.as_tensor(v[:len(res)], dtype=self.tensors[k].dtype)) + self.string_arrays = {k : StringArray(v, encoding=string_encoding) for k, v in string_lists.items()} + self.ints_arrays = {k : IntsArray(v, dtype=ints_dtype) for k, v in ints_lists.items()} + + def __getitem__(self, i): + return dict( + **{k : v[i].item() for k, v in self.tensors.items()}, + **{k : v[i] for k, v in self.string_arrays.items()}, + **{k : v[i] for k, v in self.ints_arrays.items()} + ) + + def __len__(self): + return len(next(iter(self.tensors.values()))) if len(self.tensors) > 0 else len(next(iter(self.string_arrays.values()))) + + +class NamedTupleArray(DictArray): + def __init__(self, namedtuples, *args, **kwargs): + super().__init__([t._asdict() for t in namedtuples], *args, **kwargs) + self.namedtuple = type(next(iter(namedtuples))) + + def __getitem__(self, index): + return self.namedtuple(**super().__getitem__(index)) + + +class StringArray: + def __init__(self, strings : typing.List[str], encoding : typing.Literal['ascii', 'utf_16_le', 'utf_32_le'] = 'utf_16_le'): + strings = list(strings) + self.encoding = encoding + self.multiplier = dict(ascii=1, utf_16_le=2, utf_32_le=4)[encoding] + self.data = torch.ByteTensor(torch.ByteStorage.from_buffer(''.join(strings).encode(encoding))) + self.cumlen = torch.LongTensor(list(map(len, strings))).cumsum(dim=0) + assert int(self.cumlen[-1]) * self.multiplier == len(self.data), ( + f'[{encoding}] is not enough to hold characters, use a larger character class') + + def __getitem__(self, i): + return bytes(self.data[(self.cumlen[i - 1] * self.multiplier if i >= 1 else 0) : self.cumlen[i] * self.multiplier]).decode(self.encoding) + + def __len__(self): + return len(self.cumlen) + + +class IntsArray: + def __init__(self, ints, dtype=torch.int64): + tensors = [torch.as_tensor(t, dtype=dtype) for t in ints] + self.data = torch.cat(tensors) + self.cumlen = torch.tensor(list(map(len, tensors)), dtype=torch.int64).cumsum(dim=0) + + def __getitem__(self, i): + return self.data[(self.cumlen[i - 1] if i >= 1 else 0) : self.cumlen[i]] + + def __len__(self): + return len(self.cumlen) + + +def main(): + a = StringArray(['asd', 'def']) + print('len = ', len(a)) + print('data = ', list(a)) + + a = DictArray([dict(a=1, b='def'), dict(a=2, b='klm')], types=dict(a=torch.LongTensor, b=StringArray)) + print('len = ', len(a)) + print('data = ', list(a)) + +# if __name__ == '__main__': +# main() diff --git a/dev/pyqtree_bug_mwe.py b/dev/pyqtree_bug_mwe.py new file mode 100644 index 0000000000000000000000000000000000000000..c2231a6c4e5aed7803d07d3455ca56ce2e78bb53 --- /dev/null +++ b/dev/pyqtree_bug_mwe.py @@ -0,0 +1,153 @@ + + +def pyqtree_bug_mwe(): + import pyqtree + qtree = pyqtree.Index((0, 0, 600, 480)) + oob_tlbr_box = [939, 169, 2085, 1238] + for idx in range(1, 11): + qtree.insert(idx, oob_tlbr_box) + qtree.insert(11, oob_tlbr_box) + + import pyqtree + qtree = pyqtree.Index((0, 0, 600, 600)) + oob_tlbr_box = [500, 500, 1000, 1000] + for idx in range(1, 11): + print('Insert idx = {!r}'.format(idx)) + qtree.insert(idx, oob_tlbr_box) + idx = 11 + print('Insert idx = {!r}'.format(idx)) + qtree.insert(idx, oob_tlbr_box) + + +def pyqtree_bug_test_cases(): + """ + """ + import ubelt as ub + # Test multiple cases + def basis_product(basis): + """ + Args: + basis (Dict[str, List[T]]): list of values for each axes + + Yields: + Dict[str, T] - points in the grid + """ + import itertools as it + keys = list(basis.keys()) + for vals in it.product(*basis.values()): + kw = ub.dzip(keys, vals) + yield kw + + height, width = 600, 600 + # offsets = [-100, -50, 0, 50, 100] + offsets = [-100, -10, 0, 10, 100] + # offsets = [-100, 0, 100] + x_edges = [0, width] + y_edges = [0, height] + # x_edges = [width] + # y_edges = [height] + basis = { + 'tl_x': [e + p for p in offsets for e in x_edges], + 'tl_y': [e + p for p in offsets for e in y_edges], + 'br_x': [e + p for p in offsets for e in x_edges], + 'br_y': [e + p for p in offsets for e in y_edges], + } + + # Collect and label valid cases + # M = in bounds (middle) + # T = out of bounds on the top + # L = out of bounds on the left + # B = out of bounds on the bottom + # R = out of bounds on the right + cases = [] + for item in basis_product(basis): + bbox = (item['tl_x'], item['tl_y'], item['br_x'], item['br_y']) + x1, y1, x2, y2 = bbox + if x1 < x2 and y1 < y2: + parts = [] + + if x1 < 0: + parts.append('x1=L') + elif x1 < width: + parts.append('x1=M') + else: + parts.append('x1=R') + + if x2 <= 0: + parts.append('x2=L') + elif x2 <= width: + parts.append('x2=M') + else: + parts.append('x2=R') + + if y1 < 0: + parts.append('y1=T') + elif y1 < width: + parts.append('y1=M') + else: + parts.append('y1=B') + + if y2 <= 0: + parts.append('y2=T') + elif y2 <= width: + parts.append('y2=M') + else: + parts.append('y2=B') + + assert len(parts) == 4 + label = ','.join(parts) + cases.append((label, bbox)) + + cases = sorted(cases) + print('total cases: {}'.format(len(cases))) + + failed_cases = [] + passed_cases = [] + + # We will execute the MWE in a separate python process via the "-c" + # argument so we can programatically kill cases that hang + test_case_lines = [ + 'import pyqtree', + 'bbox, width, height = {!r}, {!r}, {!r}', + 'qtree = pyqtree.Index((0, 0, width, height))', + '[qtree.insert(idx, bbox) for idx in range(1, 11)]', + 'qtree.insert(11, bbox)', + ] + + import subprocess + for label, bbox in ub.ProgIter(cases, desc='checking case', verbose=3): + pycmd = ';'.join(test_case_lines).format(bbox, width, height) + command = 'python -c "{}"'.format(pycmd) + info = ub.cmd(command, detatch=True) + proc = info['proc'] + try: + if proc.wait(timeout=0.2) != 0: + raise AssertionError + except (subprocess.TimeoutExpired, AssertionError): + # Kill cases that hang + proc.terminate() + text = 'Failed case: {}, bbox = {!r}'.format(label, bbox) + color = 'red' + failed_cases.append((label, bbox, text)) + else: + out, err = proc.communicate() + text = 'Passed case: {}, bbox = {!r}'.format(label, bbox) + color = 'green' + passed_cases.append((label, bbox, text)) + print(ub.color_text(text, color)) + print('len(failed_cases) = {}'.format(len(failed_cases))) + print('len(passed_cases) = {}'.format(len(passed_cases))) + + passed_labels = set([t[0] for t in passed_cases]) + failed_labels = set([t[0] for t in failed_cases]) + print('passed_labels = {}'.format(ub.repr2(sorted(passed_labels)))) + print('failed_labels = {}'.format(ub.repr2(sorted(failed_labels)))) + print('overlap = {}'.format(set(passed_labels) & set(failed_labels))) + + +if __name__ == '__main__': + """ + CommandLine: + python ~/code/ndsampler/dev/pyqtree_bug_mwe.py + """ + pyqtree_bug_test_cases() diff --git a/docs/todo.txt b/docs/todo.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c436a24357ac4fd838982aadbb6c35003ea66f6 --- /dev/null +++ b/docs/todo.txt @@ -0,0 +1,4 @@ +- [ ] PostGIS backend would be nice (with support for spatial intersection queries) +- [ ] 1D backend +- [ ] 3D backend +- [ ] Hashed Directory Structure for the cog cache. diff --git a/ndsampler/__init__.py b/ndsampler/__init__.py index b64389bd0d3ae3c84f85fde8f3e7ed201d7675d2..cb4905a742eb06c2a44e95d3fa0cbeded191300b 100644 --- a/ndsampler/__init__.py +++ b/ndsampler/__init__.py @@ -1,7 +1,7 @@ """ mkinit ~/code/ndsampler/ndsampler/__init__.py -w """ -__version__ = '0.5.11' +__version__ = '0.5.12' from ndsampler.utils.util_misc import (HashIdentifiable, stats_dict,) diff --git a/ndsampler/abstract_frames.py b/ndsampler/abstract_frames.py index 5715189c930dff2095db60b6bd459fba24a7b3b0..e2ba64992979b60314fecf0fe7762714bafb7c66 100644 --- a/ndsampler/abstract_frames.py +++ b/ndsampler/abstract_frames.py @@ -309,6 +309,8 @@ class Frames(object): """ if image_id not in self.id_to_hashid: # Compute the hash if we it does not exist yet + # TODO: We may be able to take advantage of DVC's cache here if we + # are in that context. gpath = self._lookup_gpath(image_id) if self.hashid_mode == 'PATH': # Hash the full path to the image data @@ -428,22 +430,25 @@ class Frames(object): return data def _load_image_full(self, image_id): - if image_id in self._lru: - return self._lru[image_id] + if self._lru is not None: + if image_id in self._lru: + return self._lru[image_id] import kwimage gpath = self._lookup_gpath(image_id) raw_data = kwimage.imread(gpath) - self._lru[image_id] = raw_data + if self._lru is not None: + self._lru[image_id] = raw_data return raw_data def _load_image_npy(self, image_id): """ Returns a memmapped reference to the entire image """ - if image_id in self._lru: - return self._lru[image_id] + if self._lru is not None: + if image_id in self._lru: + return self._lru[image_id] gpath = self._lookup_gpath(image_id) gpath, cache_gpath = self._gnames(image_id, mode='npy') @@ -493,15 +498,17 @@ class Frames(object): print('\n\n') raise - self._lru[image_id] = file + if self._lru is not None: + self._lru[image_id] = file return file def _load_image_cog(self, image_id): """ Returns a special array-like object with a COG GeoTIFF backend """ - if image_id in self._lru: - return self._lru[image_id] + if self._lru is not None: + if image_id in self._lru: + return self._lru[image_id] gpath, cache_gpath = self._gnames(image_id, mode='cog') cog_gpath = cache_gpath @@ -570,7 +577,8 @@ class Frames(object): print('') file = util_gdal.LazyGDalFrameFile(cog_gpath) - self._lru[image_id] = file + if self._lru is not None: + self._lru[image_id] = file return file @staticmethod @@ -594,11 +602,13 @@ class Frames(object): _locked_cache_write(_npy_cache_write, gpath, cache_gpath=mem_gpath, config=config) - def prepare(self, workers=0, use_stamp=True): + def prepare(self, gids=None, workers=0, use_stamp=True): """ Precompute the cached frame conversions Args: + gids (List[int] | None): specific image ids to prepare. + If None prepare all images. workers (int, default=0): number of parallel threads for this io-bound task @@ -653,30 +663,46 @@ class Frames(object): hashid = getattr(self, 'hashid', None) # TODO: - # Add some image preprocessing ability here + # Add some image preprocessing ability here? stamp = ub.CacheStamp('prepare_frames_stamp', dpath=self.cache_dpath, cfgstr=hashid, verbose=3) - stamp.cacher.enabled = bool(hashid) and bool(use_stamp) + stamp.cacher.enabled = bool(hashid) and bool(use_stamp) and gids is None # print('frames stamp hashid = {!r}'.format(hashid)) # print('frames cache_dpath = {!r}'.format(self.cache_dpath)) # print('stamp.cacher.enabled = {!r}'.format(stamp.cacher.enabled)) + if self._backend is None: + mode = None + else: + mode = self._backend['type'] + if stamp.expired() or hashid is None: from ndsampler.utils import util_futures from concurrent import futures # Use thread mode, because we are mostly in doing io. executor = util_futures.Executor(mode='thread', max_workers=workers) with executor as executor: - job_list = [] - gids = self.image_ids - for image_id in ub.ProgIter(gids, desc='Frames: submit prepare jobs'): - gpath, cache_gpath = self._gnames(image_id) - if not exists(cache_gpath): - job = executor.submit( - self.load_image, image_id, cache=True, - noreturn=True) - job_list.append(job) + if gids is None: + gids = self.image_ids + + path_list = [ + (image_id, self._gnames(image_id, mode=mode)) + for image_id in ub.ProgIter(gids, desc='lookup cache path') + ] + cache_gpath_list = [ + (image_id, cache_gpath) + for (image_id, (gpath, cache_gpath)) in ub.ProgIter(path_list, desc='check exists') + if not exists(cache_gpath) + ] + + prog = ub.ProgIter(cache_gpath_list, + desc='Frames: submit prepare jobs') + job_list = [ + executor.submit( + self.load_image, image_id, cache=True, + noreturn=True) + for image_id, cache_gpath in prog] for job in ub.ProgIter(futures.as_completed(job_list), total=len(job_list), adjust=False, freq=1, diff --git a/ndsampler/category_tree.py b/ndsampler/category_tree.py index 16df33e1ebea0718320d5206e572ae7513cf5cf3..906d1c469db7f711d67bb45a336c85a5b8b03677 100644 --- a/ndsampler/category_tree.py +++ b/ndsampler/category_tree.py @@ -15,12 +15,12 @@ Notes from YOLO-9000: hyponym of cutlery. """ from __future__ import absolute_import, division, print_function, unicode_literals -import torch import kwarray import functools import networkx as nx import ubelt as ub -import torch.nn.functional as F +# import torch +# import torch.nn.functional as F import numpy as np from kwcoco import CategoryTree as KWCOCO_CategoryTree # raw category tree @@ -42,6 +42,8 @@ class Mixin_CategoryTree_Torch: dim (int): dimension where each index corresponds to a class Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> from ndsampler.category_tree import * >>> graph = nx.generators.gnr_graph(30, 0.3, seed=321).reverse() >>> self = CategoryTree(graph) @@ -52,6 +54,8 @@ class Mixin_CategoryTree_Torch: >>> cond_probs = torch.exp(cond_logprobs).numpy() >>> assert np.allclose(cond_probs.sum(axis=1), len(self.idx_groups)) """ + import torch + import torch.nn.functional as F cond_logprobs = torch.empty_like(class_energy) if class_energy.numel() == 0: return cond_logprobs @@ -96,6 +100,7 @@ class Mixin_CategoryTree_Torch: Log-Probability chain rule: log(P(node)) = log(P(node | parent)) + log(P(parent)) """ + import torch # The dynamic program was faster on the CPU in a dummy test case memo = {} @@ -151,6 +156,8 @@ class Mixin_CategoryTree_Torch: dim (int): dimension corresponding to classes (usually 1) Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> from ndsampler.category_tree import * >>> graph = nx.generators.gnr_graph(20, 0.3, seed=328).reverse() >>> self = CategoryTree(graph) @@ -200,6 +207,8 @@ class Mixin_CategoryTree_Torch: dim (int): dimension corresponding to classes (usually 1) Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> from ndsampler.category_tree import * >>> graph = nx.generators.gnr_graph(20, 0.3, seed=328).reverse() >>> self = CategoryTree(graph) @@ -217,11 +226,15 @@ class Mixin_CategoryTree_Torch: ... torch.allclose(child_sum, p_node) Ignore: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> class_logprobs1 = self.sink_log_softmax(class_energy, dim=1) >>> class_logprobs2 = self.source_log_softmax(class_energy, dim=1) >>> class_probs1 = torch.exp(class_logprobs1) >>> class_probs2 = torch.exp(class_logprobs2) """ + import torch + import torch.nn.functional as F class_logprobs = torch.empty_like(class_energy) leaf_idxs = sorted(self.node_to_idx[node] for node in sink_nodes(self.graph)) @@ -259,6 +272,7 @@ class Mixin_CategoryTree_Torch: def hierarchical_softmax(self, class_energy, dim): """ Convinience method which converts class-energy to final probs """ + import torch class_logprobs = self.hierarchical_log_softmax(class_energy, dim) class_probs = torch.exp(class_logprobs) return class_probs @@ -270,6 +284,7 @@ class Mixin_CategoryTree_Torch: def graph_softmax(self, class_energy, dim): """ Convinience method which converts class-energy to final probs """ + import torch class_logprobs = self.hierarchical_log_softmax(class_energy, dim) class_probs = torch.exp(class_logprobs) return class_probs @@ -279,6 +294,7 @@ class Mixin_CategoryTree_Torch: """ Combines hierarchical_log_softmax and nll_loss in a single function """ + import torch.nn.functional as F class_logprobs = self.hierarchical_log_softmax(class_energy, dim=1) loss = F.nll_loss(class_logprobs, targets, reduction=reduction) return loss @@ -312,6 +328,8 @@ class Mixin_CategoryTree_Torch: targets (Tensor): true class for each example Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> from ndsampler.category_tree import * >>> graph = nx.from_dict_of_lists({ >>> 'background': [], @@ -343,6 +361,7 @@ class Mixin_CategoryTree_Torch: animal -> ['animal', 'background', 'mineral'] background -> ['animal', 'background', 'mineral'] """ + import torch.nn.functional as F loss = F.nll_loss(class_logprobs, targets) return loss @@ -357,6 +376,8 @@ class Mixin_CategoryTree_Torch: detectio metrics. Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> from ndsampler.category_tree import * >>> import torch >>> from ndsampler import category_tree @@ -434,6 +455,7 @@ class Mixin_CategoryTree_Torch: def _demo_probs(self, num=5, rng=0, nonrandom=3, hackargmax=True): """ dummy probabilities for testing """ + import torch rng = kwarray.ensure_rng(rng) class_energy = torch.FloatTensor(rng.rand(num, len(self))) @@ -491,6 +513,8 @@ class Mixin_CategoryTree_Torch: pred_conf: associated confidence Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> from ndsampler.category_tree import * >>> self = CategoryTree.demo('btree', r=3, h=3) >>> rng = kwarray.ensure_rng(0) @@ -516,6 +540,8 @@ class Mixin_CategoryTree_Torch: >>> pred_cnames = list(ub.take(self.idx_to_node, pred_idxs)) Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> from ndsampler.category_tree import * >>> self = CategoryTree.demo('btree', r=3, h=3) >>> class_probs = self._demo_probs() @@ -529,6 +555,8 @@ class Mixin_CategoryTree_Torch: >>> self.decision(class_probs, dim=1, ignore_class_idxs=self.idx_groups[1]) Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> from ndsampler.category_tree import * >>> self = CategoryTree.demo('btree', r=3, h=3, add_zero=False) >>> class_probs = self._demo_probs(num=30, nonrandom=20) @@ -543,6 +571,8 @@ class Mixin_CategoryTree_Torch: >>> assert 0 not in pred_idxs1 Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> from ndsampler.category_tree import * >>> graph = nx.from_dict_of_lists({ >>> 'a': ['b', 'q'], @@ -560,6 +590,8 @@ class Mixin_CategoryTree_Torch: >>> assert np.all(pred_idxs1 == 4) Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> # FIXME: What do we do in this case? >>> # Do we always decend at level A? >>> from ndsampler.category_tree import * @@ -808,6 +840,8 @@ def gini(probs, axis=1, impl=np): Approximates Shannon Entropy, but faster to compute Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> rng = kwarray.ensure_rng(0) >>> probs = torch.softmax(torch.Tensor(rng.rand(3, 10)), 1) >>> gini(probs.numpy(), impl=kwarray.ArrayAPI.coerce('numpy')) @@ -823,6 +857,8 @@ def entropy(probs, axis=1, impl=np): Standard Shannon (Information Theory) Entropy Example: + >>> # xdoctest: +REQUIRES(module:torch) + >>> import torch >>> rng = kwarray.ensure_rng(0) >>> probs = torch.softmax(torch.Tensor(rng.rand(3, 10)), 1) >>> entropy(probs.numpy(), impl=kwarray.ArrayAPI.coerce('numpy')) diff --git a/ndsampler/coco_frames.py b/ndsampler/coco_frames.py index abe913164fdf6b8df44897e49a096ce99f4e4351..7a24b9c37c2057b0dab13a385309113fa759ad62 100644 --- a/ndsampler/coco_frames.py +++ b/ndsampler/coco_frames.py @@ -6,6 +6,56 @@ from os.path import join import ubelt as ub +if 0: + # Maybe? + import smqtk + import abc + smqtk.Pluggable = smqtk.utils.plugin.Pluggable + smqtk.Configurable = smqtk.utils.configuration.Configurable + + """ + { + "frames": { + "ndsampler.CogFrames:0": { + 'compression': 'JPEG', + }, + "ndsampler.NPYFrames": { + }, + "type": "ndsampler.CogFrames" + } + } + """ + + class AbstractFrames(smqtk.Pluggable, smqtk.Configurable): + + def __init__(self, config, foo=1): + pass + + @abc.abstractmethod + def load_region(self, spec): + pass + + @classmethod + def is_usable(cls): + return True + + class CogFrames(AbstractFrames): + + # default_config = { + # 'compression': Value( + # default='JPEG', choices=['JPEG', 'DEFLATE']), + # } + # def __init__(self, **kwargs): + # super().__init__(**kwargs) + + def load_region(self, spec): + return spec + + class NPYFrames(AbstractFrames): + def load_region(self, spec): + return spec + + class CocoFrames(abstract_frames.Frames, util.HashIdentifiable): """ wrapper around coco-style dataset to allow for getitem syntax @@ -36,6 +86,7 @@ class CocoFrames(abstract_frames.Frames, util.HashIdentifiable): workdir=workdir, backend=backend) self.dset = dset self.verbose = verbose + self._image_ids = None def _lookup_gpath(self, image_id): img = self.dset.imgs[image_id] @@ -45,9 +96,13 @@ class CocoFrames(abstract_frames.Frames, util.HashIdentifiable): gpath = img['file_name'] return gpath - @ub.memoize_property + @property def image_ids(self): - return list(self.dset.imgs.keys()) + if self._image_ids is None: + import numpy as np + # Use ndarrays to prevent copy-on-write as best as possible + self._image_ids = np.array(list(self.dset.imgs.keys())) + return self._image_ids def _make_hashid(self): _hashid = getattr(self.dset, 'hashid', None) diff --git a/ndsampler/coco_regions.py b/ndsampler/coco_regions.py index 3fed61e3a1c5a41cb579ad0732e01ae0eabd30ac..985198abd711d90d73421a3cf17e1a397dd8d3ae 100644 --- a/ndsampler/coco_regions.py +++ b/ndsampler/coco_regions.py @@ -24,7 +24,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera import itertools as it import ubelt as ub # NOQA import numpy as np -from os.path import join +from os.path import join # NOQA from ndsampler.utils import util_misc from ndsampler import isect_indexer from ndsampler import coco_dataset @@ -661,8 +661,8 @@ class CocoRegions(Targets, util_misc.HashIdentifiable, ub.NiceRepr): verbose = self.verbose if not disable and self.hashid and self.workdir: - enabled = self._enabled_caches[fname] - dpath = join(self.workdir, '_cache', fname) + enabled = self._enabled_caches.get(fname, True) + dpath = ub.ensuredir((self.workdir, '_cache', fname)) else: dpath = None enabled = False # forced disable @@ -679,7 +679,7 @@ class CocoRegions(Targets, util_misc.HashIdentifiable, ub.NiceRepr): cfgstr = ub.hash_data(extra_deps) cacher = ub.Cacher(fname, cfgstr=cfgstr, dpath=dpath, - verbose=self.verbose, protocol=2, enabled=enabled) + verbose=self.verbose, enabled=enabled) return cacher diff --git a/ndsampler/coco_sampler.py b/ndsampler/coco_sampler.py index e5e5227359d34232b579293e58c6d65ceec9c752..02e33d55c40b05e2bec0a315d6fedc37cec17928 100644 --- a/ndsampler/coco_sampler.py +++ b/ndsampler/coco_sampler.py @@ -13,13 +13,14 @@ Example: '~/.cache/kwimage/demodata/Airport.jpg'] >>> # And you want to randomly load subregions of them in O(1) time >>> import ndsampler + >>> import kwcoco >>> # First make a COCO dataset that refers to your images (and possibly annotations) >>> dataset = { >>> 'images': [{'id': i, 'file_name': fpath} for i, fpath in enumerate(image_paths)], >>> 'annotations': [], >>> 'categories': [], >>> } - >>> coco_dset = ndsampler.CocoDataset(dataset) + >>> coco_dset = kwcoco.CocoDataset(dataset) >>> print(coco_dset) >>> # Now pass the dataset to a sampler and tell it where it can store temporary files @@ -504,6 +505,21 @@ class CocoSampler(abstract_sampler.AbstractSampler, util_misc.HashIdentifiable, >>> print('sample.shape = {!r}'.format(sample['im'].shape)) sample.shape = (6, 6, 3) + Example: + >>> # Access direct annotation information + >>> import ndsampler + >>> sampler = ndsampler.CocoSampler.demo() + >>> # Sample a region that contains at least one annotation + >>> tr = {'gid': 1, 'cx': 5, 'cy': 2, 'width': 600, 'height': 600} + >>> sample = sampler.load_sample(tr) + >>> annotation_ids = sample['annots']['aids'] + >>> aid = annotation_ids[0] + >>> # Method1: Access ann dict directly via the coco index + >>> ann = sampler.dset.anns[aid] + >>> # Method2: Access ann objects via annots method + >>> dets = sampler.dset.annots(annotation_ids).detections + >>> print('dets.data = {}'.format(ub.repr2(dets.data, nl=1))) + Example: >>> from ndsampler.coco_sampler import * >>> self = CocoSampler.demo() diff --git a/ndsampler/coerce_data.py b/ndsampler/coerce_data.py index 34ca7a3c126cf1f9a5301a422d24370a882a40e6..22a45a55e5588df1a64c716efd33c01c2ac31dd6 100644 --- a/ndsampler/coerce_data.py +++ b/ndsampler/coerce_data.py @@ -12,6 +12,7 @@ def coerce_datasets(config, build_hashid=False, verbose=1): * test_dataset Example: + >>> import kwcoco >>> import ndsampler.coerce_data >>> config = {'datasets': 'special:shapes'} >>> print('config = {!r}'.format(config)) @@ -22,16 +23,16 @@ def coerce_datasets(config, build_hashid=False, verbose=1): >>> ndsampler.coerce_data.coerce_datasets(config) >>> config = { - >>> 'datasets': ndsampler.CocoDataset.demo('shapes'), + >>> 'datasets': kwcoco.CocoDataset.demo('shapes'), >>> } >>> coerce_datasets(config) >>> coerce_datasets({ - >>> 'datasets': ndsampler.CocoDataset.demo('shapes'), - >>> 'test_dataset': ndsampler.CocoDataset.demo('photos'), + >>> 'datasets': kwcoco.CocoDataset.demo('shapes'), + >>> 'test_dataset': kwcoco.CocoDataset.demo('photos'), >>> }) >>> coerce_datasets({ - >>> 'datasets': ndsampler.CocoDataset.demo('shapes'), - >>> 'test_dataset': ndsampler.CocoDataset.demo('photos'), + >>> 'datasets': kwcoco.CocoDataset.demo('shapes'), + >>> 'test_dataset': kwcoco.CocoDataset.demo('photos'), >>> }) """ # Ideally the user specifies a standard train/vali/test split @@ -43,7 +44,7 @@ def coerce_datasets(config, build_hashid=False, verbose=1): def _ensure_coco(coco): # Map a file path or an in-memory dataset to a CocoDataset - import ndsampler + import kwcoco import six from os.path import exists if coco is None: @@ -51,8 +52,10 @@ def coerce_datasets(config, build_hashid=False, verbose=1): elif isinstance(coco, six.string_types): fpath = _rectify_fpath(coco) if exists(fpath): - # print('read dataset: fpath = {!r}'.format(fpath)) - coco = ndsampler.CocoDataset(fpath) + with ub.Timer('read kwcoco dataset: fpath = {!r}'.format(fpath)): + coco = kwcoco.CocoDataset(fpath, autobuild=False) + print('building kwcoco index') + coco._build_index() else: if not coco.lower().startswith('special:'): import warnings @@ -60,10 +63,10 @@ def coerce_datasets(config, build_hashid=False, verbose=1): code = coco else: code = coco.lower()[len('special:'):] - coco = ndsampler.CocoDataset.demo(code) + coco = kwcoco.CocoDataset.demo(code) else: # print('live dataset') - assert isinstance(coco, ndsampler.CocoDataset) + assert isinstance(coco, kwcoco.CocoDataset) return coco config = config.copy() diff --git a/ndsampler/isect_indexer.py b/ndsampler/isect_indexer.py index 9f7449324c097df7ce44327712b7990f5d0e09c0..b9f3067502307329c3b2aec85eeaf5506cfe44d6 100644 --- a/ndsampler/isect_indexer.py +++ b/ndsampler/isect_indexer.py @@ -9,7 +9,7 @@ import kwarray import kwimage -class FrameIntersectionIndex(object): +class FrameIntersectionIndex(ub.NiceRepr): """ Build spatial tree for each frame so we can quickly determine if a random negative is too close to a positive. For each frame/image we built a qtree. @@ -34,6 +34,12 @@ class FrameIntersectionIndex(object): self.qtrees = None self.all_gids = None + def __nice__(self): + if self.all_gids is None: + return 'None' + else: + return len(self.all_gids) + @classmethod def from_coco(cls, dset): """ @@ -70,6 +76,9 @@ class FrameIntersectionIndex(object): @staticmethod def _build_index(dset): + """ + + """ qtrees = { img['id']: pyqtree.Index((0, 0, img['width'], img['height'])) for img in dset.dataset['images'] diff --git a/ndsampler/utils/util_gdal.py b/ndsampler/utils/util_gdal.py index dc9f7a69f029e7f91503cf72fa938272133151c8..9fff10d391d5d208808134c20aeee2da2e196e98 100644 --- a/ndsampler/utils/util_gdal.py +++ b/ndsampler/utils/util_gdal.py @@ -597,10 +597,24 @@ class LazyGDalFrameFile(ub.NiceRepr): return ds @classmethod - def demo(cls): - from ndsampler.abstract_frames import SimpleFrames - self = SimpleFrames.demo() - self = self._load_image_cog(1) + def demo(cls, key='astro', dsize=None): + """ + Ignore: + >>> from ndsampler.utils.util_gdal import * # NOQA + >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400)) + """ + cache_dpath = ub.ensure_app_cache_dir('ndsampler/demo') + fpath = join(cache_dpath, key + '.cog.tiff') + depends = ub.odict(dsize=dsize) + cfgstr = ub.hash_data(depends) + stamp = ub.CacheStamp(fname=key, cfgstr=cfgstr, dpath=cache_dpath, + product=[fpath]) + if stamp.expired(): + import kwimage + img = kwimage.grab_test_image(key, dsize=dsize) + kwimage.imwrite(fpath, img, backend='gdal') + stamp.renew() + self = cls(fpath) return self @property @@ -629,6 +643,16 @@ class LazyGDalFrameFile(ub.NiceRepr): """ References: https://gis.stackexchange.com/questions/162095/gdal-driver-create-typeerror + + Ignore: + >>> from ndsampler.utils.util_gdal import * # NOQA + >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400)) + >>> index = [slice(2100, 2508, None), slice(4916, 5324, None), None] + >>> img_part = self[index] + >>> # xdoctest: +REQUIRES(--show) + >>> import kwplot + >>> kwplot.autompl() + >>> kwplot.imshow(img_part) """ ds = self._ds width = ds.RasterXSize @@ -655,24 +679,34 @@ class LazyGDalFrameFile(ub.NiceRepr): rb_indices = range(C) assert len(trailing_part) <= 1 - # TODO: preallocate like kwimage - channels = [] - for i in rb_indices: - rb = ds.GetRasterBand(1 + i) - xsize = rb.XSize - ysize = rb.YSize - - ystart, ystop = map(int, [ypart.start, ypart.stop]) - ysize = ystop - ystart - - xstart, xstop = map(int, [xpart.start, xpart.stop]) - xsize = xstop - xstart - - gdalkw = dict(xoff=xstart, yoff=ystart, win_xsize=xsize, win_ysize=ysize) - channel = rb.ReadAsArray(**gdalkw) - channels.append(channel) - - img_part = np.dstack(channels) + ystart, ystop = map(int, [ypart.start, ypart.stop]) + xstart, xstop = map(int, [xpart.start, xpart.stop]) + + ysize = ystop - ystart + xsize = xstop - xstart + + gdalkw = dict(xoff=xstart, yoff=ystart, + win_xsize=xsize, win_ysize=ysize) + + PREALLOC = 1 + if PREALLOC: + # preallocate like kwimage.im_io._imread_gdal + from kwimage.im_io import _gdal_to_numpy_dtype + shape = (ysize, xsize, len(rb_indices)) + bands = [ds.GetRasterBand(1 + rb_idx) + for rb_idx in rb_indices] + gdal_dtype = bands[0].DataType + dtype = _gdal_to_numpy_dtype(gdal_dtype) + img_part = np.empty(shape, dtype=dtype) + for out_idx, rb in enumerate(bands): + img_part[:, :, out_idx] = rb.ReadAsArray(**gdalkw) + else: + channels = [] + for rb_idx in rb_indices: + rb = ds.GetRasterBand(1 + rb_idx) + channel = rb.ReadAsArray(**gdalkw) + channels.append(channel) + img_part = np.dstack(channels) return img_part def validate(self, orig_fpath=None, orig_data=None): diff --git a/ndsampler/utils/util_sklearn.py b/ndsampler/utils/util_sklearn.py index 866bf4385fbb2c0799a1201a37efdc7138ff05ff..64cc4354527c03f2ab3ea5fee3d9c799181c66db 100644 --- a/ndsampler/utils/util_sklearn.py +++ b/ndsampler/utils/util_sklearn.py @@ -26,6 +26,8 @@ class StratifiedGroupKFold(_BaseKFold): """ def __init__(self, n_splits=3, shuffle=False, random_state=None): + if not shuffle: + random_state = None super(StratifiedGroupKFold, self).__init__(n_splits, shuffle, random_state) def _make_test_folds(self, X, y=None, groups=None): diff --git a/setup.py b/setup.py index 3f33db9e8c5a51a7723d0afd883a19568131332c..54d4f36daf339e90459abaae838c5346d900de6d 100755 --- a/setup.py +++ b/setup.py @@ -150,7 +150,6 @@ if __name__ == '__main__': # This should be interpreted as Apache License v2.0 'License :: OSI Approved :: Apache Software License', # Supported Python versions - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', ], )