|
| 1 | +# partition_writer module for MicroPython on ESP32 |
| 2 | +# MIT license; Copyright (c) 2023 Glenn Moloney @glenn20 |
| 3 | + |
| 4 | +# Based on OTA class by Thorsten von Eicken (@tve): |
| 5 | +# https://github.com/tve/mqboard/blob/master/mqrepl/mqrepl.py |
| 6 | + |
| 7 | +import hashlib |
| 8 | +import io |
| 9 | + |
| 10 | +from micropython import const |
| 11 | + |
| 12 | +IOCTL_BLOCK_COUNT: int = const(4) # type: ignore |
| 13 | +IOCTL_BLOCK_SIZE: int = const(5) # type: ignore |
| 14 | +IOCTL_BLOCK_ERASE: int = const(6) # type: ignore |
| 15 | + |
| 16 | + |
| 17 | +# An IOBase compatible class to wrap access to an os.AbstractBlockdev() device |
| 18 | +# such as a partition on the device flash. Writes must be aligned to block |
| 19 | +# boundaries. |
| 20 | +# https://docs.micropython.org/en/latest/library/os.html#block-device-interface |
| 21 | +# Extend IOBase so we can wrap this with io.BufferedWriter in BlockdevWriter |
| 22 | +class Blockdev(io.IOBase): |
| 23 | + def __init__(self, device): |
| 24 | + self.device = device |
| 25 | + self.blocksize = int(device.ioctl(IOCTL_BLOCK_SIZE, None)) |
| 26 | + self.blockcount = int(device.ioctl(IOCTL_BLOCK_COUNT, None)) |
| 27 | + self.pos = 0 # Current position (bytes from beginning) of device |
| 28 | + self.end = 0 # Current end of the data written to the device |
| 29 | + |
| 30 | + # Data must be a multiple of blocksize unless it is the last write to the |
| 31 | + # device. The next write after a partial block will raise ValueError. |
| 32 | + def write(self, data: bytes | bytearray | memoryview) -> int: |
| 33 | + block, remainder = divmod(self.pos, self.blocksize) |
| 34 | + if remainder: |
| 35 | + raise ValueError(f"Block {block} write not aligned at block boundary.") |
| 36 | + data_len = len(data) |
| 37 | + nblocks, remainder = divmod(data_len, self.blocksize) |
| 38 | + mv = memoryview(data) |
| 39 | + if nblocks: # Write whole blocks |
| 40 | + self.device.writeblocks(block, mv[: nblocks * self.blocksize]) |
| 41 | + block += nblocks |
| 42 | + if remainder: # Write left over data as a partial block |
| 43 | + self.device.ioctl(IOCTL_BLOCK_ERASE, block) # Erase block first |
| 44 | + self.device.writeblocks(block, mv[-remainder:], 0) |
| 45 | + self.pos += data_len |
| 46 | + self.end = self.pos # The "end" of the data written to the device |
| 47 | + return data_len |
| 48 | + |
| 49 | + # Read data from the block device. |
| 50 | + def readinto(self, data: bytearray | memoryview): |
| 51 | + size = min(len(data), self.end - self.pos) |
| 52 | + block, remainder = divmod(self.pos, self.blocksize) |
| 53 | + self.device.readblocks(block, memoryview(data)[:size], remainder) |
| 54 | + self.pos += size |
| 55 | + return size |
| 56 | + |
| 57 | + # Set the current file position for reading or writing |
| 58 | + def seek(self, offset: int, whence: int = 0): |
| 59 | + start = [0, self.pos, self.end] |
| 60 | + self.pos = start[whence] + offset |
| 61 | + |
| 62 | + |
| 63 | +# Calculate the SHA256 sum of a file (has a readinto() method) |
| 64 | +def sha_file(f, buffersize=4096) -> str: |
| 65 | + mv = memoryview(bytearray(buffersize)) |
| 66 | + read_sha = hashlib.sha256() |
| 67 | + while (n := f.readinto(mv)) > 0: |
| 68 | + read_sha.update(mv[:n]) |
| 69 | + return read_sha.digest().hex() |
| 70 | + |
| 71 | + |
| 72 | +# BlockdevWriter provides a convenient interface to writing images to any block |
| 73 | +# device which implements the micropython os.AbstractBlockDev interface (eg. |
| 74 | +# Partition on flash storage on ESP32). |
| 75 | +# https://docs.micropython.org/en/latest/library/os.html#block-device-interface |
| 76 | +# https://docs.micropython.org/en/latest/library/esp32.html#flash-partitions |
| 77 | +class BlockDevWriter: |
| 78 | + def __init__( |
| 79 | + self, |
| 80 | + device, # Block device to recieve the data (eg. esp32.Partition) |
| 81 | + verify: bool = True, # Should we read back and verify data after writing |
| 82 | + verbose: bool = True, |
| 83 | + ): |
| 84 | + self.device = Blockdev(device) |
| 85 | + self.writer = io.BufferedWriter( |
| 86 | + self.device, self.device.blocksize # type: ignore |
| 87 | + ) |
| 88 | + self._sha = hashlib.sha256() |
| 89 | + self.verify = verify |
| 90 | + self.verbose = verbose |
| 91 | + self.sha: str = "" |
| 92 | + self.length: int = 0 |
| 93 | + blocksize, blockcount = self.device.blocksize, self.device.blockcount |
| 94 | + if self.verbose: |
| 95 | + print(f"Device capacity: {blockcount} x {blocksize} byte blocks.") |
| 96 | + |
| 97 | + def set_sha_length(self, sha: str, length: int): |
| 98 | + self.sha = sha |
| 99 | + self.length = length |
| 100 | + blocksize, blockcount = self.device.blocksize, self.device.blockcount |
| 101 | + if length > blocksize * blockcount: |
| 102 | + raise ValueError(f"length ({length} bytes) is > size of partition.") |
| 103 | + if self.verbose and length: |
| 104 | + blocks, remainder = divmod(length, blocksize) |
| 105 | + print(f"Writing {blocks} blocks + {remainder} bytes.") |
| 106 | + |
| 107 | + def print_progress(self): |
| 108 | + if self.verbose: |
| 109 | + block, remainder = divmod(self.device.pos, self.device.blocksize) |
| 110 | + print(f"\rBLOCK {block}", end="") |
| 111 | + if remainder: |
| 112 | + print(f" + {remainder} bytes") |
| 113 | + |
| 114 | + # Append data to the block device |
| 115 | + def write(self, data: bytearray | bytes | memoryview) -> int: |
| 116 | + self._sha.update(data) |
| 117 | + n = self.writer.write(data) |
| 118 | + self.print_progress() |
| 119 | + return n |
| 120 | + |
| 121 | + # Append data from f (a stream object) to the block device |
| 122 | + def write_from_stream(self, f: io.BufferedReader) -> int: |
| 123 | + mv = memoryview(bytearray(self.device.blocksize)) |
| 124 | + tot = 0 |
| 125 | + while (n := f.readinto(mv)) != 0: |
| 126 | + tot += self.write(mv[:n]) |
| 127 | + return tot |
| 128 | + |
| 129 | + # Flush remaining data to the block device and confirm all checksums |
| 130 | + # Raises: |
| 131 | + # ValueError("SHA mismatch...") if SHA of received data != expected sha |
| 132 | + # ValueError("SHA verify fail...") if verified SHA != written sha |
| 133 | + def close(self) -> None: |
| 134 | + self.writer.flush() |
| 135 | + self.print_progress() |
| 136 | + # Check the checksums (SHA256) |
| 137 | + nbytes: int = self.device.end |
| 138 | + if self.length and self.length != nbytes: |
| 139 | + raise ValueError(f"Received {nbytes} bytes (expect {self.length}).") |
| 140 | + write_sha = self._sha.digest().hex() |
| 141 | + if not self.sha: |
| 142 | + self.sha = write_sha |
| 143 | + if self.sha != write_sha: |
| 144 | + raise ValueError(f"SHA mismatch recv={write_sha} expect={self.sha}.") |
| 145 | + if self.verify: |
| 146 | + if self.verbose: |
| 147 | + print("Verifying SHA of the written data...", end="") |
| 148 | + self.device.seek(0) # Reset to start of partition |
| 149 | + read_sha = sha_file(self.device, self.device.blocksize) |
| 150 | + if read_sha != write_sha: |
| 151 | + raise ValueError(f"SHA verify failed write={write_sha} read={read_sha}") |
| 152 | + if self.verbose: |
| 153 | + print("Passed.") |
| 154 | + if self.verbose or not self.sha: |
| 155 | + print(f"SHA256={self.sha}") |
| 156 | + self.device.seek(0) # Reset to start of partition |
| 157 | + |
| 158 | + def __enter__(self): |
| 159 | + return self |
| 160 | + |
| 161 | + def __exit__(self, e_t, e_v, e_tr): |
| 162 | + if e_t is None: |
| 163 | + self.close() |
0 commit comments