diff --git a/glymur/jp2k.py b/glymur/jp2k.py index 87028e1..d460ea6 100644 --- a/glymur/jp2k.py +++ b/glymur/jp2k.py @@ -13,8 +13,10 @@ import sys # pylint: disable=E0611 if sys.hexversion >= 0x03030000: from contextlib import ExitStack + from itertools import compress, filterfalse else: from contextlib2 import ExitStack + from itertools import compress, ifilterfalse as filterfalse from collections import Counter import ctypes @@ -28,9 +30,6 @@ import warnings import numpy as np from .codestream import Codestream -from .core import SRGB, GREYSCALE -from .core import PROGRESSION_ORDER -from .core import ENUMERATED_COLORSPACE, RESTRICTED_ICC_PROFILE from . import core from .jp2box import Jp2kBox from .jp2box import JPEG2000SignatureBox, FileTypeBox, JP2HeaderBox @@ -148,8 +147,8 @@ class Jp2k(Jp2kBox): jp2h = [box for box in self.box if box.box_id == 'jp2h'][0] colrs = [box for box in jp2h.box if box.box_id == 'colr'] for colr in colrs: - if colr.method not in (ENUMERATED_COLORSPACE, - RESTRICTED_ICC_PROFILE): + if colr.method not in (core.ENUMERATED_COLORSPACE, + core.RESTRICTED_ICC_PROFILE): msg = "Color Specification box method must specify either " msg += "an enumerated colorspace or a restricted ICC " msg += "profile if the file type box brand is 'jp2 '." @@ -310,7 +309,7 @@ class Jp2k(Jp2kBox): if 'prog' in kwargs: prog = kwargs['prog'].upper() - cparams.prog_order = PROGRESSION_ORDER[prog] + cparams.prog_order = core.PROGRESSION_ORDER[prog] if 'psnr' in kwargs: cparams.tcp_numlayers = len(kwargs['psnr']) @@ -742,11 +741,11 @@ class Jp2k(Jp2kBox): width = codestream.segment[1].xsiz num_components = len(codestream.segment[1].xrsiz) if num_components < 3: - colorspace = GREYSCALE + colorspace = core.GREYSCALE else: if len(self.box) == 0: # Best guess is SRGB - colorspace = SRGB + colorspace = core.SRGB else: # Take whatever the first jp2 header / color specification # says. @@ -763,10 +762,10 @@ class Jp2k(Jp2kBox): """ Slicing protocol. """ - if isinstance(index, slice) and ( - index.start == None and + if ((isinstance(index, slice) and + (index.start == None and index.stop == None and - index.step == None): + index.step == None)) or (index is Ellipsis)): # Case of jp2[:] = data, i.e. write the entire image. # # Should have a slice object where start = stop = step = None @@ -780,31 +779,70 @@ class Jp2k(Jp2kBox): Slicing protocol. """ codestream = self.get_codestream(header_only=True) + numrows = codestream.segment[1].ysiz + numcols = codestream.segment[1].xsiz + numbands = codestream.segment[1].Csiz + if isinstance(pargs, int): # Not a very good use of this protocol, but technically legal. # This retrieves a single row. row = pargs - area = (row, 0, row + 1, codestream.segment[1].xsiz) + area = (row, 0, row + 1, numcols) return self.read(area=area).squeeze() - if isinstance(pargs, slice): - # Case of jp2[:], i.e. retrieve the entire image. - # - # Should have a slice object where start = stop = step = None + if pargs is Ellipsis: + # Case of jp2[...] return self.read() - if isinstance(pargs, tuple) and all(isinstance(x, int) for x in pargs): - # Retrieve a single pixel. - # Something like jp2[r, c] - row = pargs[0] - col = pargs[1] - area = (row, col, row + 1, col + 1) - pixel = self.read(area=area).squeeze() - - if len(pargs) == 2: - return pixel - elif len(pargs) == 3: - return pixel[pargs[2]] + if isinstance(pargs, slice): + if pargs.start is None and pargs.stop is None and pargs.step is None: + # Case of jp2[:] + return self.read() + + # Corner case of jp2[x] where x is a slice object with non-null + # members. Just augment it with an ellipsis and let the code + # below handle it. + pargs = (pargs, Ellipsis) + + if isinstance(pargs, tuple) and any(x is Ellipsis for x in pargs): + # Remove the first ellipsis we find. + rows = slice(0, numrows) + cols = slice(0, numcols) + bands = slice(0, numbands) + if pargs[0] is Ellipsis: + if len(pargs) == 2: + newindex = (rows, cols, pargs[1]) + else: + newindex = (rows, pargs[1], pargs[2]) + elif pargs[1] is Ellipsis: + if len(pargs) == 2: + newindex = (pargs[0], cols, bands) + else: + newindex = (pargs[0], cols, pargs[2]) + else: + # Assume that we don't have 4D imagery, of course. + newindex = (pargs[0], pargs[1], bands) + + # Run once again because it is possible that there's another + # Ellipsis object in the 2nd or 3rd position. + return self.__getitem__(newindex) + + if isinstance(pargs, tuple) and any(isinstance(x, int) for x in pargs): + # Replace the first such integer argument, replace it with a slice. + lst = list(pargs) + predicate = lambda x: not isinstance(x[1], int) + g = filterfalse(predicate, enumerate(pargs)) + idx = next(g)[0] + lst[idx] = slice(pargs[idx], pargs[idx] + 1) + newindex = tuple(lst) + + # Invoke array-based slicing again, as there may be additional + # integer argument remaining. + data = self.__getitem__(newindex) + + # Reduce dimensionality in the scalar dimension. + return np.squeeze(data, axis=idx) + # Assuming pargs is a tuple of slices from now on. rows = pargs[0] @@ -814,16 +852,8 @@ class Jp2k(Jp2kBox): else: bands = pargs[2] - if rows.step is None: - rows_step = 1 - else: - rows_step = rows.step - - if cols.step is None: - cols_step = 1 - else: - cols_step = cols.step - + rows_step = 1 if rows.step is None else rows.step + cols_step = 1 if cols.step is None else cols.step if rows_step != cols_step: msg = "Row and column strides must be the same." raise IndexError(msg) @@ -838,27 +868,12 @@ class Jp2k(Jp2kBox): raise IndexError(msg) rlevel = np.int(np.round(np.log2(step))) - if rows.start is None: - rows_start = 0 - else: - rows_start = rows.start - - if rows.stop is None: - rows_stop = codestream.segment[1].ysiz - else: - rows_stop = rows.stop - - if cols.start is None: - cols_start = 0 - else: - cols_start = cols.start - - if cols.stop is None: - cols_stop = codestream.segment[1].xsiz - else: - cols_stop = cols.stop - - area = (rows_start, cols_start, rows_stop, cols_stop) + area = ( + 0 if rows.start is None else rows.start, + 0 if cols.start is None else cols.start, + numrows if rows.stop is None else rows.stop, + numcols if cols.stop is None else cols.stop + ) data = self.read(area=area, rlevel=rlevel) if len(pargs) == 2: return data @@ -1475,14 +1490,14 @@ def _validate_channel_definition(jp2h, colr): raise IOError(msg) elif len(cdef_lst) == 1: cdef = jp2h.box[cdef_lst[0]] - if colr.colorspace == SRGB: + if colr.colorspace == core.SRGB: if any([chan + 1 not in cdef.association or cdef.channel_type[chan] != 0 for chan in [0, 1, 2]]): msg = "All color channels must be defined in the " msg += "channel definition box." raise IOError(msg) - elif colr.colorspace == GREYSCALE: + elif colr.colorspace == core.GREYSCALE: if 0 not in cdef.channel_type: msg = "All color channels must be defined in the " msg += "channel definition box." diff --git a/glymur/test/test_jp2k.py b/glymur/test/test_jp2k.py index 2868e84..042a681 100644 --- a/glymur/test/test_jp2k.py +++ b/glymur/test/test_jp2k.py @@ -71,6 +71,16 @@ class SliceProtocolBase(unittest.TestCase): @unittest.skipIf(os.name == "nt", "NamedTemporaryFile issue on windows") class TestSliceProtocolBaseWrite(SliceProtocolBase): + def test_write_ellipsis(self): + expected = self.j2k_data + + with tempfile.NamedTemporaryFile(suffix='.j2k') as tfile: + j = Jp2k(tfile.name, 'wb') + j[...] = self.j2k_data + actual = j.read() + + np.testing.assert_array_equal(actual, expected) + def test_basic_write(self): expected = self.j2k_data @@ -236,6 +246,42 @@ class TestSliceProtocolRead(SliceProtocolBase): expected = self.jp2.read(area=(0, 0, 202, 202), rlevel=1) np.testing.assert_array_equal(actual, expected) + def test_ellipsis_full_read(self): + actual = self.j2k[...] + expected = self.j2k_data + np.testing.assert_array_equal(actual, expected) + + def test_ellipsis_band_select(self): + actual = self.j2k[..., 0] + expected = self.j2k_data[..., 0] + np.testing.assert_array_equal(actual, expected) + + def test_ellipsis_row_select(self): + actual = self.j2k[0, ...] + expected = self.j2k_data[0, ...] + np.testing.assert_array_equal(actual, expected) + + def test_two_ellipsis_band_select(self): + actual = self.j2k[..., ..., 1] + expected = self.j2k_data[:, :, 1] + np.testing.assert_array_equal(actual, expected) + + def test_two_ellipsis_row_select(self): + actual = self.j2k[1, ..., ...] + expected = self.j2k_data[1, :, :] + np.testing.assert_array_equal(actual, expected) + + def test_two_ellipsis_and_full_slice(self): + actual = self.j2k[..., ..., :] + expected = self.j2k_data[:] + np.testing.assert_array_equal(actual, expected) + + def test_single_slice(self): + rows = slice(3, 8) + actual = self.j2k[rows] + expected = self.j2k_data[3:8, :,:] + np.testing.assert_array_equal(actual, expected) + def test_slice_protocol_2d_reduce_resolution(self): d = self.j2k[:] self.assertEqual(d.shape, (800, 480, 3))