diff --git a/glymur/jp2box.py b/glymur/jp2box.py index 311e76f..f62567d 100644 --- a/glymur/jp2box.py +++ b/glymur/jp2box.py @@ -1882,21 +1882,14 @@ class PaletteBox(Jp2kBox): fptr.write(write_buffer) bps = self.bits_per_component - if any(b != bps[0] for b in bps): + if all(b == bps[0] for b in bps): # All components are the same. Writing is straightforward. if self.bits_per_component[0] <= 8: - code = 'B' - dtype = np.uint8 + write_buffer = np.getbuffer(self.palette.astype(np.uint8)) elif self.bits_per_component[0] <= 16: - code = 'H' - dtype = np.uint16 + write_buffer = np.getbuffer(self.palette.astype(np.uint16)) elif self.bits_per_component[0] <= 32: - code = 'I' - dtype = np.uint32 - nelts = self.palette.shape[0] * self.palette.shape[1] - fmt = '>{0}{1}'.format(nelts, code) - write_buffer = struct.pack(fmt, - self.palette.astype(dtype).flatten()) + write_buffer = np.getbuffer(self.palette.astype(np.uint32)) fptr.write(write_buffer) else: # Not all the components are the same. More general, but much rarer @@ -1933,27 +1926,46 @@ class PaletteBox(Jp2kBox): # Need to determine bps and signed or not read_buffer = fptr.read(num_columns) - data = struct.unpack('>' + 'B' * num_columns, read_buffer) - bps = [((x & 0x7f) + 1) for x in data] - signed = [((x & 0x80) > 1) for x in data] + bps_signed = struct.unpack('>' + 'B' * num_columns, read_buffer) + bps = [((x & 0x7f) + 1) for x in bps_signed] + signed = [((x & 0x80) > 1) for x in bps_signed] - fmt = '>' - for bits in bps: - if bits <= 8: - fmt += 'B' - elif bits <= 16: - fmt += 'H' - elif bits <= 32: - fmt += 'I' + if all(b == bps_signed[0] for b in bps_signed): + # Ok the palette has the same datatype for all columns. We should + # be able to efficiently read it. + if bps[0] <= 8: + nbytes_per_row = num_columns + dtype = np.uint8 + elif bps[0] <= 16: + nbytes_per_row = 2 * num_columns + dtype = np.uint16 + elif bps[0] <= 32: + nbytes_per_row = 3 * num_columns + dtype = np.uint32 + + read_buffer = fptr.read(num_entries * nbytes_per_row) + palette = np.frombuffer(read_buffer, dtype=dtype) + palette = np.reshape(palette, (num_entries, num_columns)) - # Each palette component is padded out to the next largest byte. - # That means a list comprehension does this in one shot. - row_nbytes = sum([int(math.ceil(x/8.0)) for x in bps]) + else: + # General case where the columns may not be the same width. + fmt = '>' + for bits in bps: + if bits <= 8: + fmt += 'B' + elif bits <= 16: + fmt += 'H' + elif bits <= 32: + fmt += 'I' - read_buffer = fptr.read(num_entries * row_nbytes) - palette = np.zeros((num_entries, num_columns), dtype=np.int32) - for j in range(num_entries): - palette[j] = struct.unpack_from(fmt, read_buffer, + # Each palette component is padded out to the next largest byte. + # That means a list comprehension does this in one shot. + row_nbytes = sum([int(math.ceil(x/8.0)) for x in bps]) + + read_buffer = fptr.read(num_entries * row_nbytes) + palette = np.zeros((num_entries, num_columns), dtype=np.int32) + for j in range(num_entries): + palette[j] = struct.unpack_from(fmt, read_buffer, offset=j * row_nbytes) return cls(palette, bps, signed, length=length, offset=offset)