import hashlib
from unittest import mock

import pytest

import zigpy.ota.image as firmware
import zigpy.types as t

from .conftest import FILES_DIR

# Opaque placeholder values. ``mock.sentinel`` hands back the same unique object
# for a given attribute name, so these only ever compare by identity.
MANUFACTURER_ID, IMAGE_TYPE = (
    mock.sentinel.manufacturer_id,
    mock.sentinel.image_type,
)


@pytest.fixture
def image():
    """Build a small, self-consistent ``OTAImage`` for serialization tests.

    The image has a 56-byte header and a single subelement (tag ``0x0000``)
    whose payload is ``b"data"``. ``image_size`` is set to the header length
    plus the serialized subelement, so the image serializes cleanly.
    """
    header_length = 56
    payload = b"data"
    # Subelement framing: 2-byte tag id + 4-byte length, followed by the payload.
    total_size = header_length + 2 + 4 + len(payload)

    ota_image = firmware.OTAImage()
    ota_image.header = firmware.OTAImageHeader(
        upgrade_file_id=firmware.OTAImageHeader.MAGIC_VALUE,
        header_version=256,
        header_length=header_length,
        field_control=0,
        manufacturer_id=9876,
        image_type=123,
        file_version=12345,
        stack_version=2,
        header_string="This is a test header!",
        image_size=total_size,
    )
    ota_image.subelements = [firmware.SubElement(tag_id=0x0000, data=payload)]
    return ota_image


def test_image_serialization_bad_length(image):
    """``serialize`` succeeds only when ``image_size`` matches the real size."""
    correct_size = image.header.image_size

    # Alternate between the correct size (must serialize) and a size that is
    # off by one in each direction (must raise).
    for delta in (1, -1):
        image.header.image_size = correct_size
        assert image.serialize()

        image.header.image_size = correct_size + delta
        with pytest.raises(ValueError):
            image.serialize()


def test_hw_version():
    """``HWVersion(0x0A01)`` decodes to version 10 (high byte), revision 1 (low byte)."""
    hw = firmware.HWVersion(0x0A01)
    assert (hw.version, hw.revision) == (10, 1)

    text = repr(hw)
    for fragment in ("version=10", "revision=1"):
        assert fragment in text


def _test_ota_img_header(field_control, hdr_suffix=b"", extra=b""):
    """Deserialize a fixed OTA header and check the fields common to all tests.

    :param field_control: the two serialized field-control bytes.
    :param hdr_suffix: optional header fields appended after the fixed part
        (which ones are present is selected by ``field_control``).
    :param extra: trailing bytes that the header parser must leave unconsumed.
    :return: ``(header, remaining_bytes)`` as returned by
        ``OTAImageHeader.deserialize``.
    """
    header_bytes = b"".join(
        (
            # magic, header version 0x0100, header length 0x38
            b"\x1e\xf1\xee\x0b\x00\x018\x00",
            field_control,
            # manufacturer id, image type, file version, stack version,
            # 32-byte header string and total image size
            b"|\x11\x01!rE!\x12\x02\x00EBL tradfri_light_basic\x00\x00\x00"
            b"\x00\x00\x00\x00\x00\x00~\x91\x02\x00",
            hdr_suffix,
        )
    )

    hdr, rest = firmware.OTAImageHeader.deserialize(header_bytes + extra)

    expected_fields = {
        "header_version": 0x0100,
        "header_length": 0x0038,
        "manufacturer_id": 4476,
        "image_type": 0x2101,
        "file_version": 0x12214572,
        "stack_version": 0x0002,
        "image_size": 0x0002917E,
    }
    for field_name, expected in expected_fields.items():
        assert getattr(hdr, field_name) == expected, field_name

    # Re-serializing must reproduce the header exactly, without ``extra``.
    assert hdr.serialize() == header_bytes

    return hdr, rest


def test_ota_image_header():
    """Optional-field flags are ``None`` on a blank header, ``False`` when parsed."""
    flag_names = (
        "security_credential_version_present",
        "device_specific_file",
        "hardware_versions_present",
    )

    blank = firmware.OTAImageHeader()
    for name in flag_names:
        assert getattr(blank, name) is None

    trailing = b"abcdefghklmnpqr"

    hdr, rest = _test_ota_img_header(b"\x00\x00", extra=trailing)
    assert rest == trailing
    for name in flag_names:
        assert getattr(hdr, name) is False


def test_ota_image_header_security():
    """Field-control bit 0 adds a one-byte security credential version."""
    trailing = b"abcdefghklmnpqr"
    credential = t.uint8_t(0xAC)

    hdr, rest = _test_ota_img_header(b"\x01\x00", credential.serialize(), trailing)

    assert rest == trailing
    assert hdr.security_credential_version_present is True
    assert hdr.security_credential_version == credential
    assert hdr.device_specific_file is False
    assert hdr.hardware_versions_present is False


def test_ota_image_header_hardware_versions():
    """Field-control bit 2 adds minimum and maximum hardware versions."""
    trailing = b"abcdefghklmnpqr"
    # Note: min > max here; the test only checks that both values round-trip.
    lowest = firmware.HWVersion(0xBEEF)
    highest = firmware.HWVersion(0xABCD)
    suffix = lowest.serialize() + highest.serialize()

    hdr, rest = _test_ota_img_header(b"\x04\x00", suffix, trailing)

    assert rest == trailing
    assert hdr.security_credential_version_present is False
    assert hdr.device_specific_file is False
    assert hdr.hardware_versions_present is True
    assert hdr.minimum_hardware_version == lowest
    assert hdr.maximum_hardware_version == highest


def test_ota_image_destination():
    """Field-control bit 1 adds an EUI64 upgrade file destination."""
    trailing = b"abcdefghklmnpqr"
    destination, _ = t.EUI64.deserialize(b"12345678")

    hdr, rest = _test_ota_img_header(b"\x02\x00", destination.serialize(), trailing)

    assert rest == trailing
    assert hdr.security_credential_version_present is False
    assert hdr.device_specific_file is True
    assert hdr.upgrade_file_destination == destination
    assert hdr.hardware_versions_present is False


def test_ota_img_wrong_header():
    """A header with a corrupted magic value is rejected, with or without trailing data."""
    # Second byte is 0xF0 instead of 0xF1, so the magic value does not match.
    bad_header = (
        b"\x1e\xf0\xee\x0b\x00\x018\x00\x00\x00"
        b"|\x11\x01!rE!\x12\x02\x00EBL tradfri_light_basic\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00~\x91\x02\x00"
    )

    for candidate in (bad_header, bad_header + b"123abc"):
        with pytest.raises(ValueError):
            firmware.OTAImageHeader.deserialize(candidate)


def test_header_string():
    """``HeaderString`` is a fixed-size (32-byte), NUL-padded field.

    Checks deserialization with and without trailing bytes, rejection of raw
    values of the wrong length, and the text -> bytes round trip.
    """
    size = 32
    header_string = "This is a header String"
    data = header_string.encode("utf8").ljust(size, b"\x00")
    extra = b"cdef123"

    # Bytes after the fixed-size field are returned untouched, and the decoded
    # value is the same as when no trailing bytes are present.
    hdr_str, rest = firmware.HeaderString.deserialize(data + extra)
    assert rest == extra
    assert header_string in str(hdr_str)

    # Raw byte values must be exactly ``size`` long: too short and too long fail.
    with pytest.raises(ValueError):
        firmware.HeaderString(b"foo")

    with pytest.raises(ValueError):
        firmware.HeaderString(b"a" * (size + 1))

    hdr_str, rest = firmware.HeaderString.deserialize(data)
    assert rest == b""
    assert header_string in str(hdr_str)
    assert firmware.HeaderString(header_string).serialize() == data


def test_header_string_roundtrip_invalid():
    """Header strings that are not valid UTF-8 round-trip and are displayed as hex."""
    # Contains 0xff, which can never appear in valid UTF-8.
    raw = bytes.fromhex(
        "5a757d364000603e400013704000010000009f364000b015400020904000ffff"
    )

    parsed, remainder = firmware.HeaderString.deserialize(raw)
    assert not remainder
    assert parsed == firmware.HeaderString(raw)

    assert parsed.serialize() == raw
    assert raw.hex() in str(parsed)


def test_header_string_too_short():
    """Input shorter than the 32-byte field cannot be deserialized."""
    truncated = "This is a header String".encode("utf8")

    with pytest.raises(ValueError):
        firmware.HeaderString.deserialize(truncated)


def test_subelement():
    """A subelement is a 2-byte tag id, a 4-byte length and then the payload."""
    payload = b"\x00payload\xff"
    length_field = t.uint32_t(len(payload)).serialize()
    # Tag bytes for ECDSA_SIGNATURE_CRYPTO_SUITE_1.
    encoded = b"\x01\x00" + length_field + payload
    trailing = b"extra"

    element, rest = firmware.SubElement.deserialize(encoded + trailing)

    assert rest == trailing
    assert element.tag_id == firmware.ElementTagId.ECDSA_SIGNATURE_CRYPTO_SUITE_1
    assert element.data == payload
    assert len(element.data) == len(payload)
    assert element.serialize() == encoded


def test_subelement_too_short():
    """Subelements shorter than their header or their declared length are rejected."""
    # A subelement header is a 2-byte tag id plus a 4-byte length, so every
    # input of 0..5 bytes is truncated, including the empty input and the
    # boundary case one byte short of a full header.
    header_size = 2 + 4
    for size in range(header_size):
        with pytest.raises(ValueError):
            firmware.SubElement.deserialize(b"\x00" * size)

    # A bare header that declares a zero-length payload is a valid, empty element.
    element, rest = firmware.SubElement.deserialize(b"\x00" * header_size)
    assert element.data == b""
    assert rest == b""

    # The length field says 2 bytes, but only 1 byte of payload follows.
    with pytest.raises(ValueError):
        firmware.SubElement.deserialize(b"\x00\x02\x02\x00\x00\x00a")


def test_subelement_repr():
    """Subelement reprs show the payload length and a (possibly elided) hex dump."""
    cases = (
        # A 32-byte payload is printed in full.
        (32, "32:0000000000000000000000000000000000000000000000000000000000000000"),
        # A 33-byte payload is shortened with "..." in the middle.
        (33, "33:00000000000000000000000000000000000000000000000000...00000000000000"),
    )

    for size, expected in cases:
        element = firmware.SubElement(
            tag_id=firmware.ElementTagId.UPGRADE_IMAGE, data=b"\x00" * size
        )
        assert expected in repr(element)


@pytest.fixture
def raw_header():
    """Return a factory that builds a serialized 56-byte OTA header.

    The factory takes ``elements_size`` (the total serialized size of the
    subelements that will follow) and stores ``56 + elements_size`` in the
    header's trailing ``image_size`` field.
    """
    header_length = 56
    # Everything before the 4-byte image_size field (52 bytes).
    fixed_prefix = (
        b"\x1e\xf1\xee\x0b\x00\x018\x00\x00\x00"
        b"|\x11\x01!rE!\x12\x02\x00EBL tradfri_light_basic\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00"
    )

    def build(elements_size=0):
        image_size = t.uint32_t(elements_size + header_length)
        return fixed_prefix + image_size.serialize()

    return build


@pytest.fixture
def raw_sub_element():
    """Return a factory serializing a subelement as tag id, length, payload."""

    def build(tag_id, payload=b""):
        return b"".join(
            (
                t.uint16_t(tag_id).serialize(),
                t.uint32_t(len(payload)).serialize(),
                payload,
            )
        )

    return build


def test_ota_image(raw_header, raw_sub_element):
    """An image with two subelements round-trips; truncated data is rejected."""
    first_payload, second_payload = b"abcd", b"4321"
    elements = raw_sub_element(0, first_payload) + raw_sub_element(1, second_payload)
    serialized = raw_header(len(elements)) + elements
    trailing = b"edbc321"

    img, rest = firmware.OTAImage.deserialize(serialized + trailing)

    assert rest == trailing
    assert [(sub.tag_id, sub.data) for sub in img.subelements] == [
        (0, first_payload),
        (1, second_payload),
    ]
    assert img.serialize() == serialized

    # Dropping the last payload byte makes image_size disagree with the data.
    with pytest.raises(ValueError):
        firmware.OTAImage.deserialize(serialized[:-1])


def wrap_ikea(data):
    """Wrap *data* in a synthetic IKEA ("NGIS") container.

    Layout: a 100-byte zero-filled header with the ``NGIS`` magic at offset 0,
    the header length at offset 16 and the payload length at offset 20 (both
    4-byte little-endian), then the payload, then 512 ``b"F"`` filler bytes.

    :return: the wrapped container as a ``bytearray``.
    """
    header_size = 100
    header = bytearray(header_size)
    header[:4] = b"NGIS"
    header[16:20] = header_size.to_bytes(4, "little")
    header[20:24] = len(data).to_bytes(4, "little")

    # Trailing filler block; contents are arbitrary for these tests.
    filler = b"F" * 512
    return header + data + filler


def test_parse_ota_normal(image):
    """A bare OTA image parses back to itself with nothing left over."""
    parsed = firmware.parse_ota_image(image.serialize())
    assert parsed == (image, b"")


def test_parse_ota_ikea(image):
    """An image wrapped in an IKEA container is unwrapped transparently."""
    wrapped = wrap_ikea(image.serialize())
    expected = (image, b"")
    assert firmware.parse_ota_image(wrapped) == expected


def test_parse_ota_ikea_trailing(image):
    """Extra bytes inside an IKEA container are absorbed into the parsed image.

    The trailing bytes are appended to the subelement payload and counted in
    ``image_size``, and the result still re-serializes cleanly.
    """
    inner = image.serialize() + b"trailing"

    parsed, remaining = firmware.parse_ota_image(wrap_ikea(inner))
    assert not remaining

    assert parsed.header.image_size == len(inner)
    assert parsed.subelements[0].data == b"data" + b"trailing"

    _, remaining_after_roundtrip = firmware.OTAImage.deserialize(parsed.serialize())
    assert not remaining_after_roundtrip


@pytest.mark.parametrize(
    "data",
    (
        # The magic followed by only a handful of bytes.
        b"NGIS" + b"truncated",
        # A longer blob that still does not contain a complete image.
        b"NGIS" + b"long enough to container header but not actual image",
    ),
)
def test_parse_ota_ikea_truncated(data):
    """Truncated IKEA ("NGIS") containers are rejected with ``ValueError``."""
    with pytest.raises(ValueError):
        firmware.parse_ota_image(data)


def create_hue_ota(data):
    """Build a Hue-style OTA image: a standard OTA header followed by raw bytes.

    The body is ``b"\\x2a\\x00\\x01" + data``. The header's ``image_size`` is
    rewritten to cover the header plus that body.

    :return: the serialized header followed by the body.
    """
    body = b"\x2a\x00\x01" + data

    template = bytes.fromhex(
        "1ef1ee0b0001380000000b100301d5670042020000000000000000000000000000000000000000"
        "0000000000000000000000000038f00300"
    )
    header, _ = firmware.OTAImageHeader.deserialize(template)
    header.image_size = len(header.serialize()) + len(body)

    return header.serialize() + body


def test_parse_ota_hue():
    """Hue images parse as ``HueSBLOTAImage``; bytes after the image are returned."""
    tail = b"rest"
    blob = create_hue_ota(b"test") + tail

    img, rest = firmware.parse_ota_image(blob)

    assert isinstance(img, firmware.HueSBLOTAImage)
    assert rest == tail
    assert img.data == b"\x2a\x00\x01" + b"test"
    assert img.serialize() + tail == blob


def test_parse_ota_hue_invalid():
    """Hue images are rejected when truncated, mis-prefixed, or not from Hue."""
    blob = create_hue_ota(b"test")
    # Sanity check: the unmodified image parses without raising.
    firmware.parse_ota_image(blob)

    with pytest.raises(ValueError):
        firmware.parse_ota_image(blob[:-1])

    header, body = firmware.OTAImageHeader.deserialize(blob)
    assert blob == header.serialize() + body

    with pytest.raises(ValueError):
        # Three byte sequence must be the first thing after the header
        firmware.parse_ota_image(header.serialize() + b"\xff" + body[1:])

    with pytest.raises(ValueError):
        # Only Hue is known to use these images
        firmware.parse_ota_image(header.replace(manufacturer_id=12).serialize() + body)


def test_legrand_container_unwrapping(image):
    """Legrand containers wrap the image in a size prefix and a fixed suffix."""
    payload = image.serialize()
    # Unwrapped size prefix and 1 + 16 byte suffix
    size_prefix = t.uint32_t(len(payload)).serialize()
    suffix = b"\x01" + b"abcdabcdabcdabcd"
    container = size_prefix + payload + suffix

    # A truncated suffix or a corrupted size prefix must both be rejected.
    for corrupted in (container[:-1], b"\xff" + container[1:]):
        with pytest.raises(ValueError):
            firmware.parse_ota_image(corrupted)

    img, rest = firmware.parse_ota_image(container)
    assert not rest
    assert img == image


def test_thirdreality_container(image):
    """Third Reality containers: SHA-512 digest, a sub-header, then the image."""
    image_bytes = image.serialize()
    digest_size = 64  # length of a SHA-512 digest in bytes
    image_offset = 152  # offset of the OTA image within the full container

    # There's little useful information in the header
    subcontainer = b"".join(
        (
            t.uint32_t(16).serialize(),
            # Total length of image, excluding SHA512 prefix
            t.uint32_t(len(image_bytes) + image_offset - digest_size).serialize(),
            t.uint32_t(image_offset).serialize(),
            # Unknown four byte prefix/suffix and what looks like a second SHA512 hash
            b"?" * (digest_size + 4),
            t.uint32_t(0).serialize(),
            t.uint32_t(0).serialize(),
            image_bytes,
        )
    )

    container = hashlib.sha512(subcontainer).digest() + subcontainer
    assert container.index(image_bytes) == image_offset

    img, rest = firmware.parse_ota_image(container)
    assert not rest
    assert img == image

    with pytest.raises(ValueError):
        firmware.parse_ota_image(container[:-1])

    with pytest.raises(ValueError):
        firmware.parse_ota_image(b"\xff" + container[1:])


def test_encrypted_telink_container() -> None:
    """The SNZB-01M firmware file parses as a Telink image and round-trips exactly."""
    ota_path = FILES_DIR / "external/dl/sonoff/snzb-01m_v1.0.5.ota"
    raw = ota_path.read_bytes()

    img, rest = firmware.parse_ota_image(raw)
    assert isinstance(img, firmware.TelinkOTAImage)
    assert not rest

    assert img.serialize() == raw


def test_telink_encrypted_subelement_repr() -> None:
    """The repr lists tag id, tag info and an elided hex dump of the payload."""
    expected = (
        "<TelinkEncryptedSubElement(tag_id=<ElementTagId.undefined_0xf000: 61440>"
        ", tag_info=4660, data=[40:ababababababababababababababababababababababababab"
        "...ababababababab])>"
    )

    element = firmware.TelinkEncryptedSubElement(
        tag_id=0xF000,
        tag_info=0x1234,
        data=b"\xab" * 40,
    )
    assert repr(element) == expected


def test_telink_encrypted_subelement_deserialize_errors() -> None:
    """``TelinkEncryptedSubElement.deserialize`` rejects each kind of malformed input."""
    deserialize = firmware.TelinkEncryptedSubElement.deserialize

    # 7 bytes: one short of the tag id (2) + length (4) + tag info (2) fields
    # laid out in ``short_data`` below.
    with pytest.raises(ValueError, match="Data too short to contain encrypted"):
        deserialize(b"\x00" * 7)

    wrong_tag = b"".join(
        (
            t.uint16_t(0x0000).serialize(),  # tag_id should be 0xF000
            t.uint32_t(10).serialize(),
            b"\x00" * 12,
        )
    )
    with pytest.raises(ValueError, match="Not a Telink encrypted subelement"):
        deserialize(wrong_tag)

    # Declared length (100) exceeds the bytes that actually follow.
    short_data = b"".join(
        (
            t.uint16_t(0xF000).serialize(),
            t.uint32_t(100).serialize(),
            t.uint16_t(0).serialize(),
            b"\x00" * 10,
        )
    )
    with pytest.raises(ValueError, match="Data too short to contain Telink subelement"):
        deserialize(short_data)


def test_telink_ota_image_serialize_bad_length() -> None:
    """Serializing a Telink image fails once its header ``image_size`` is wrong."""
    raw = (FILES_DIR / "external/dl/sonoff/snzb-01m_v1.0.5.ota").read_bytes()
    img, _ = firmware.parse_ota_image(raw)

    # Make the declared size disagree with the real image by one byte.
    img.header.image_size = img.header.image_size + 1

    with pytest.raises(ValueError, match="does not match actual image size"):
        img.serialize()
