Source code for psd_tools.psd.patterns

"""
Patterns structure.
"""

import io
import logging
from typing import Any, BinaryIO, TypeVar

from typing_extensions import Self

from attrs import define, field

from psd_tools.compression import compress, decompress
from psd_tools.constants import ColorMode, Compression
from psd_tools.psd.base import BaseElement, ListElement
from psd_tools.psd.bin_utils import (
    is_readable,
    read_fmt,
    read_length_block,
    read_pascal_string,
    read_unicode_string,
    write_bytes,
    write_fmt,
    write_length_block,
    write_pascal_string,
    write_unicode_string,
)
from psd_tools.validators import in_

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type variables bound to the classes defined below; they annotate the ``cls``
# parameter and return type of the ``read`` classmethods.
T_Patterns = TypeVar("T_Patterns", bound="Patterns")
T_Pattern = TypeVar("T_Pattern", bound="Pattern")


class Patterns(ListElement):
    """
    List of Pattern structure.

    Every entry is stored as its own length block (read and written with
    ``padding=4``) whose payload is a single
    :py:class:`~psd_tools.psd.patterns.Pattern`.
    """

    @classmethod
    def read(cls: type[T_Patterns], fp: BinaryIO, **kwargs: Any) -> T_Patterns:
        """
        Read pattern blocks from ``fp`` for as long as ``is_readable(fp, 4)``
        holds.

        :param fp: binary stream positioned at the first pattern block.
        :return: a new instance holding the parsed patterns in file order.
        """
        patterns = []
        while is_readable(fp, 4):
            block = read_length_block(fp, padding=4)
            # Parse from an isolated buffer so a pattern cannot read past
            # the end of its own block.
            with io.BytesIO(block) as stream:
                pattern = Pattern.read(stream)
            patterns.append(pattern)
        return cls(patterns)  # type: ignore[arg-type]

    def write(self, fp: BinaryIO, **kwargs: Any) -> int:
        """
        Write every pattern to ``fp`` as a length block with ``padding=4``.

        :param fp: binary stream to write to.
        :return: sum of the values returned by ``write_length_block``.
        """
        return sum(write_length_block(fp, item.write, padding=4) for item in self)
@define(repr=False)
class Pattern(BaseElement):
    """
    A single pattern record.

    .. py:attribute:: version

        Structure version; :py:meth:`read` only accepts ``1``.

    .. py:attribute:: image_mode

        See :py:class:`ColorMode`

    .. py:attribute:: point

        Size in tuple.

    .. py:attribute:: name

        `str` name of the pattern.

    .. py:attribute:: pattern_id

        ID of this pattern.

    .. py:attribute:: color_table

        Color table if the mode is INDEXED.

    .. py:attribute:: data

        See :py:class:`VirtualMemoryArrayList`
    """

    version: int = 1
    image_mode: ColorMode = field(
        default=ColorMode.RGB, converter=ColorMode, validator=in_(ColorMode)
    )
    point: tuple[int, int] = field(default=(0, 0))
    name: str = ""
    pattern_id: str = ""
    color_table: list[tuple[int, int, int]] = field(factory=list)
    data: "VirtualMemoryArrayList" = field(default=None)

    @classmethod
    def read(cls: type[T_Pattern], fp: BinaryIO, **kwargs: Any) -> T_Pattern:
        """
        Parse one pattern from ``fp``.

        :param fp: binary stream positioned at the pattern's version field.
        :return: the parsed pattern; ``color_table`` is ``None`` unless the
            image mode is ``ColorMode.INDEXED``.
        :raises AssertionError: if the stored version is not ``1``.
        """
        version = read_fmt("I", fp)[0]
        assert version == 1, "Invalid version %d" % (version)
        mode_value = read_fmt("I", fp)[0]
        image_mode = ColorMode(mode_value)
        point = read_fmt("2h", fp)
        name = read_unicode_string(fp)
        pattern_id = read_pascal_string(fp, encoding="ascii", padding=1)

        palette = None
        if image_mode == ColorMode.INDEXED:
            # 256 three-byte entries, then 4 bytes that are skipped.
            palette = []
            for _ in range(256):
                palette.append(read_fmt("3B", fp))
            read_fmt("4x", fp)

        arrays = VirtualMemoryArrayList.read(fp)
        return cls(version, image_mode, point, name, pattern_id, palette, arrays)  # type: ignore[arg-type]

    def write(self, fp: BinaryIO, **kwargs: Any) -> int:
        """
        Serialize this pattern to ``fp`` in the order :py:meth:`read` expects.

        :param fp: binary stream to write to.
        :return: sum of the sizes reported by the individual writers.
        """
        total = write_fmt(fp, "2I", self.version, self.image_mode.value)
        total += write_fmt(fp, "2h", *self.point)
        total += write_unicode_string(fp, self.name)
        total += write_pascal_string(fp, self.pattern_id, encoding="ascii", padding=1)
        # The table (plus its 4 trailing pad bytes) is emitted only when it
        # is non-empty; ``None`` and ``[]`` both skip it.
        if self.color_table:
            total += sum(write_fmt(fp, "3B", *rgb) for rgb in self.color_table)
            total += write_fmt(fp, "4x")
        return total + self.data.write(fp)
@define(repr=False)
class VirtualMemoryArrayList(BaseElement):
    """
    VirtualMemoryArrayList structure. Container of channels.

    The channel count stored in the file is two less than the number of
    :py:class:`VirtualMemoryArray` entries that are read and written.

    .. py:attribute:: version

        Structure version; :py:meth:`read` only accepts ``3``.

    .. py:attribute:: rectangle

        Tuple of `int`

    .. py:attribute:: channels

        List of :py:class:`VirtualMemoryArray`
    """

    version: int = 3
    rectangle: tuple[int, int, int, int] = field(default=(0, 0, 0, 0))
    channels: list["VirtualMemoryArray"] = field(factory=list)

    @classmethod
    def read(cls, fp: BinaryIO, **kwargs: Any) -> Self:
        """
        Parse the version field and the length block that follows it.

        :param fp: binary stream positioned at the version field.
        :return: the parsed container.
        :raises AssertionError: if the stored version is not ``3``.
        """
        version = read_fmt("I", fp)[0]
        assert version == 3, "Invalid version %d" % (version)
        payload = read_length_block(fp)
        with io.BytesIO(payload) as stream:
            rectangle = read_fmt("4I", stream)
            # Two more arrays follow than the stored count declares.
            array_count = read_fmt("I", stream)[0] + 2
            channels = [VirtualMemoryArray.read(stream) for _ in range(array_count)]
        return cls(version, rectangle, channels)

    def write(self, fp: BinaryIO, **kwargs: Any) -> int:
        """
        Write the version followed by the body wrapped in a length block.

        :param fp: binary stream to write to.
        :return: sum of the sizes reported by the two writers.
        """
        header_size = write_fmt(fp, "I", self.version)
        body_size = write_length_block(fp, self._write_body)
        return header_size + body_size

    def _write_body(self, fp: BinaryIO) -> int:
        """Write the rectangle, the stored channel count and every channel."""
        size = write_fmt(fp, "4I", *self.rectangle)
        # Mirror of ``read``: the stored count excludes two of the arrays.
        size += write_fmt(fp, "I", len(self.channels) - 2)
        return size + sum(channel.write(fp) for channel in self.channels)
@define(repr=False)
class VirtualMemoryArray(BaseElement):
    """
    VirtualMemoryArray structure, corresponding to each channel.

    A channel can be in one of three states: not written
    (``is_written == 0``), written but empty (``is_written != 0`` with
    ``depth``, ``rectangle`` and ``pixel_depth`` left as ``None``), or
    written with a payload.

    .. py:attribute:: is_written

        Non-zero when the channel is present in the file.

    .. py:attribute:: depth
    .. py:attribute:: rectangle

        Tuple of four `int`; :py:meth:`get_data` uses index 2 as the height
        and index 3 as the width.

    .. py:attribute:: pixel_depth
    .. py:attribute:: compression

        See :py:class:`Compression`

    .. py:attribute:: data

        Channel bytes as stored, i.e. still compressed.
    """

    is_written: int = 0
    depth: int | None = None
    rectangle: tuple[int, int, int, int] | None = None
    pixel_depth: int | None = None
    compression: Compression = field(
        default=Compression.RAW, converter=Compression, validator=in_(Compression)
    )
    data: bytes = b""

    @classmethod
    def read(cls, fp: BinaryIO, **kwargs: Any) -> Self:
        """
        Parse one channel from ``fp``.

        :param fp: binary stream positioned at the ``is_written`` flag.
        :return: the parsed channel; only ``is_written`` is populated when
            the flag is zero or the declared length is zero.
        """
        is_written = read_fmt("I", fp)[0]
        if is_written == 0:
            return cls(is_written=is_written)
        length = read_fmt("I", fp)[0]
        if length == 0:
            # Flagged as written but without a payload.
            return cls(is_written=is_written)
        depth = read_fmt("I", fp)[0]
        rectangle = read_fmt("4I", fp)
        pixel_depth, compression = read_fmt("HB", fp)
        # ``length`` also counts the 23 header bytes (4 + 16 + 2 + 1) read
        # just above, so only the remainder is pixel data.
        # NOTE(review): a declared length below 23 makes this size negative;
        # confirm how the stream's read() treats that for malformed files.
        data = fp.read(length - 23)
        return cls(is_written, depth, rectangle, pixel_depth, compression, data)

    def write(self, fp: BinaryIO, **kwargs: Any) -> int:
        """
        Serialize this channel to ``fp``, mirroring the states of ``read``.

        :param fp: binary stream to write to.
        :return: sum of the sizes reported by the individual writers.
        """
        written = write_fmt(fp, "I", self.is_written)
        if self.is_written == 0:
            return written
        if self.depth is None:
            # Written-but-empty channel: emit a zero length and stop.
            written += write_fmt(fp, "I", 0)
            return written
        return written + write_length_block(fp, self._write_body)

    def _write_body(self, fp: BinaryIO) -> int:
        """Write depth, rectangle, pixel depth, compression and the data."""
        assert self.depth is not None
        assert self.rectangle is not None
        written = write_fmt(fp, "I", self.depth)
        written += write_fmt(fp, "4I", *self.rectangle)
        written += write_fmt(fp, "HB", self.pixel_depth, self.compression.value)
        written += write_bytes(fp, self.data)
        return written

    def get_data(self) -> bytes | None:
        """
        Get decompressed bytes.

        :return: the decompressed channel data, or ``None`` when the channel
            is not written or carries no payload.
        """
        if not self.is_written:
            return None
        # ``read`` yields channels that are flagged as written yet have no
        # payload (declared length 0) and ``write`` round-trips them; there
        # is nothing to decompress, so report "no data" instead of failing.
        if self.depth is None or self.rectangle is None:
            return None
        width, height = self.rectangle[3], self.rectangle[2]
        return decompress(
            self.data, self.compression, width, height, self.depth, version=1
        )

    def set_data(
        self, size: tuple[int, int], data: bytes, depth: int, compression: int = 0
    ) -> None:
        """
        Set bytes.

        :param size: ``(width, height)`` of the channel.
        :param data: uncompressed channel bytes.
        :param depth: value stored as both ``depth`` and ``pixel_depth``.
        :param compression: :py:class:`Compression` value used to encode
            ``data``.
        """
        method = Compression(compression)
        self.data = compress(data, method, size[0], size[1], depth, version=1)
        self.depth = int(depth)
        self.pixel_depth = int(depth)
        # Stored with a zero origin, height at index 2 and width at index 3,
        # which is how ``get_data`` reads it back.
        self.rectangle = (0, 0, int(size[1]), int(size[0]))
        self.compression = method
        self.is_written = True