# protocol.py (Source)

#!/usr/bin/env python3
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Parser for a simple text-based frame format."""

import binascii
import enum
import string


class ProtocolParser:
    """Incremental, byte-at-a-time parser for a simple text frame format.

    A frame consists of, in order:

    * the sync character ``$``;
    * exactly ``EXPECTED_LENGTH_BYTES`` decimal digits giving the message
      length;
    * the delimiter ``:``;
    * exactly that many message bytes (a ``:`` inside the message is
      reported as "message too short", so the payload cannot contain it);
    * the delimiter ``:``;
    * exactly ``EXPECTED_CHECKSUM_BYTES`` hexadecimal digits (either case);
    * a line feed ending the frame.

    The checksum must equal ``binascii.crc_hqx(message, 0xFFFF)``.

    Attributes:
        state: Current ``ProtocolState`` of the state machine.
        length: Message length decoded from the length field so far.
        message: ``bytearray`` with the message bytes collected so far; it
            holds the full payload right after ``next_byte`` returns True.
        checksum: ``bytearray`` with the checksum digits collected so far.
        error: Text of the most recent protocol error ("" if none yet).
        bytes_left: Number of bytes still expected in the current field.
    """

    class ProtocolState(enum.Enum):
        """States of the frame parser, one per frame field plus ERROR."""
        SYNC = 0
        LENGTH = 1
        MESSAGE = 2
        CHECKSUM = 3
        ERROR = 4

    SYNC_CHARACTER = b'$'
    MESSAGE_DELIMITER = b':'
    FRAME_END = b'\n'

    EXPECTED_LENGTH_BYTES = 3
    EXPECTED_CHECKSUM_BYTES = 4

    ERROR_LENGTH_TOO_SHORT = "Length field is too short"
    ERROR_LENGTH_TOO_LONG = "Length field is too long"
    ERROR_INVALID_LENGTH_BYTE = "Invalid length field byte"

    ERROR_MESSAGE_TOO_SHORT = "Message field is too short"
    ERROR_MESSAGE_TOO_LONG = "Message field is too long"

    ERROR_CHECKSUM_TOO_SHORT = "Checksum field is too short"
    ERROR_CHECKSUM_TOO_LONG = "Checksum field is too long"
    ERROR_INVALID_CHECKSUM_BYTE = "Invalid checksum field byte"
    ERROR_CHECKSUM_MISMATCH = "Checksums do not match"

    # Characters accepted in the checksum field (upper and lower case),
    # encoded once here instead of once per received byte.
    _HEX_DIGITS = string.hexdigits.encode()

    def __init__(self):
        self.state = self.ProtocolState.SYNC
        self.length = 0
        self.message = bytearray()
        self.checksum = bytearray()
        self.error = ""

        self.bytes_left = 0

    def next_byte(self, byte) -> bool:
        """Feed one received byte into the state machine.

        Args:
            byte: The next byte of the stream, either as a ``bytes`` object
                of length 1 or as an ``int`` (as produced by iterating over
                ``bytes``).  An empty value, such as the ``b''`` returned by
                a read that timed out, is ignored.  Longer inputs are not
                split into single bytes and are not supported.

        Returns:
            True if this byte completed a frame with a valid checksum; the
            payload is then available in ``self.message``.  False otherwise.

        On a protocol error the reason is stored in ``self.error``, a
        diagnostic line is printed, and the parser goes back to waiting for
        the next sync character.
        """
        if isinstance(byte, int):
            # Iterating over bytes yields ints; normalise to a one-byte
            # bytes object so the comparisons below can match.
            byte = bytes((byte,))
        elif not byte:
            # Nothing was received.  Without this guard an empty read would
            # be counted as a message/checksum byte and corrupt the frame.
            return False

        states = self.ProtocolState
        message_received = False

        if self.state == states.SYNC:
            self._on_sync(byte)
        elif self.state == states.LENGTH:
            self._on_length(byte)
        elif self.state == states.MESSAGE:
            self._on_message(byte)
        elif self.state == states.CHECKSUM:
            message_received = self._on_checksum(byte)

        if self.state == states.ERROR:
            # Simplistic error handling:
            # Print error and wait for the next message
            print(f"A protocol error occured: {self.error}")
            self.state = states.SYNC

        return message_received

    def _fail(self, reason) -> None:
        """Record *reason* and switch to the ERROR state."""
        self.state = self.ProtocolState.ERROR
        self.error = reason

    def _on_sync(self, byte) -> None:
        """Wait for the sync character; every other byte is ignored."""
        if byte == self.SYNC_CHARACTER:
            # Expect length information
            self.state = self.ProtocolState.LENGTH
            self.bytes_left = self.EXPECTED_LENGTH_BYTES
            self.length = 0

    def _on_length(self, byte) -> None:
        """Accumulate the decimal length field up to its delimiter."""
        if byte == self.MESSAGE_DELIMITER:
            if self.bytes_left > 0:
                # Expected more length bytes
                self._fail(self.ERROR_LENGTH_TOO_SHORT)
            else:
                # Expect message field
                self.state = self.ProtocolState.MESSAGE
                self.bytes_left = self.length
                self.message = bytearray()
        elif self.bytes_left <= 0:
            # Expected fewer length bytes
            self._fail(self.ERROR_LENGTH_TOO_LONG)
        else:
            self.bytes_left -= 1
            if byte.isdigit():
                self.length = self.length * 10 + int(byte)
            else:
                self._fail(self.ERROR_INVALID_LENGTH_BYTE)

    def _on_message(self, byte) -> None:
        """Collect message bytes until the announced length is reached."""
        if byte == self.MESSAGE_DELIMITER:
            if self.bytes_left > 0:
                # Expected more message bytes
                self._fail(self.ERROR_MESSAGE_TOO_SHORT)
            else:
                # Expect checksum field
                self.state = self.ProtocolState.CHECKSUM
                self.bytes_left = self.EXPECTED_CHECKSUM_BYTES
                self.checksum = bytearray()
        elif self.bytes_left <= 0:
            # Expected fewer message bytes
            self._fail(self.ERROR_MESSAGE_TOO_LONG)
        else:
            self.bytes_left -= 1
            self.message += byte

    def _on_checksum(self, byte) -> bool:
        """Collect checksum digits; return True when a valid frame ends."""
        if byte == self.FRAME_END:
            if self.bytes_left > 0:
                # Expected more checksum bytes
                self._fail(self.ERROR_CHECKSUM_TOO_SHORT)
            elif (binascii.crc_hqx(self.message, 0xFFFF)
                    != int(self.checksum, base=16)):
                self._fail(self.ERROR_CHECKSUM_MISMATCH)
            else:
                # Message complete; wait for the next one
                self.state = self.ProtocolState.SYNC
                return True
        elif self.bytes_left <= 0:
            # Expected fewer checksum bytes
            self._fail(self.ERROR_CHECKSUM_TOO_LONG)
        else:
            self.bytes_left -= 1
            if byte in self._HEX_DIGITS:
                self.checksum += byte
            else:
                self._fail(self.ERROR_INVALID_CHECKSUM_BYTE)
        return False


def parse_data(data):
    """Run *data* through a fresh ProtocolParser and print decoded messages.

    Every message for which the parser reports a completed frame is printed
    as ``Decoded: <message>``.  Nothing is returned.
    """
    parser = ProtocolParser()
    # Feed one-byte slices (not the ints that indexing would give), matching
    # the bytes constants the parser compares against.
    single_bytes = (data[pos:pos + 1] for pos in range(len(data)))
    for chunk in single_bytes:
        frame_complete = parser.next_byte(chunk)
        if frame_complete:
            print(f"Decoded: {parser.message}")


if __name__ == "__main__":
    # Demo stream: a valid frame first and last, and in between one frame
    # for each protocol error the parser can report.  The two letters in
    # front of every '$' are inter-frame noise that the parser skips.
    frames = (
        b'ab$017:message arg1 arg2:7CD8\n',   # valid
        b'cd$17:message arg1 arg2:7CD8\n',    # length field too short
        b'ef$0017:message arg1 arg2:7CD8\n',  # length field too long
        b'gh$A01:message arg1 arg2:7CD8\n',   # invalid length byte
        b'ij$016:message arg1 arg2:7CD8\n',   # message longer than announced
        b'kl$018:message arg1 arg2:7CD8\n',   # message shorter than announced
        b'mn$017:message arg1 arg2:7CD9\n',   # checksum mismatch
        b'op$017:message arg1 arg2:7CD\n',    # checksum field too short
        b'qr$017:message arg1 arg2:07CD8\n',  # checksum field too long
        b'st$017:message arg1 arg2:ZCD8\n',   # invalid checksum byte
        b'uv$017:message arg1 arg2:7CD8\n',   # valid
    )
    test_data = b''.join(frames)

    parse_data(test_data)