"""srv6endx.py

Created by Vincent Bernat
Copyright (c) 2025 Exa Networks. All rights reserved.
"""

from __future__ import annotations

import json
from struct import unpack
from typing import Callable

from exabgp.bgp.message.notification import Notify
from exabgp.bgp.message.update.attribute.bgpls.linkstate import FlagLS, LinkState
from exabgp.protocol.ip import IPv6
from exabgp.util import hexstring

# Size of the fixed part of an SRv6 End.X SID TLV value (RFC 9514, Section 4.1),
# i.e. everything that precedes the optional sub-TLVs:
#   Endpoint Behavior (2) + Flags (1) + Algorithm (1) + Weight (1)
#   + Reserved (1) + SID (16) = 22 bytes
SRV6_ENDX_MIN_LENGTH = 2 + 1 + 1 + 1 + 1 + 16

# RFC 9514, Section 4.1: SRv6 End.X SID TLV
#  0                   1                   2                   3
#  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# |               Type            |          Length               |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# |        Endpoint Behavior      |      Flags    |   Algorithm   |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# |     Weight    |   Reserved    |  SID (16 octets) ...          |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# |    SID (cont ...)                                             |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# |    SID (cont ...)                                             |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# |    SID (cont ...)                                             |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# |    SID (cont ...)             | Sub-TLVs (variable) . . .
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+


@LinkState.register_lsid()
class Srv6EndX(FlagLS):
    """BGP-LS SRv6 End.X SID TLV (type 1106, RFC 9514 Section 4.1).

    Keeps the raw TLV value together with its decoded form. The decoded form
    is a list of dicts (one per End.X SID) so that several TLVs of this type
    can be accumulated into one object through ``merge``.
    """

    TLV = 1106
    # Names of the 8 flag bits decoded by unpack_flags(): B, S and P (the
    # Backup, Set and Persistent flags of RFC 9514 / RFC 9352), rest reserved.
    FLAGS = ['B', 'S', 'P', 'RSV', 'RSV', 'RSV', 'RSV', 'RSV']
    # NOTE(review): presumably tells LinkState to combine repeated TLVs of this
    # type via merge() rather than keeping them separately - confirm in LinkState.
    MERGE = True
    # Sub-TLV type code -> decoder class, populated by register_subsubtlv().
    # Defined at class level, so the registry is shared by all instances.
    registered_subsubtlvs: dict[int, type] = dict()

    def __init__(self, packed: bytes, parsed_content: dict[str, object] | None = None) -> None:
        """Initialize with packed bytes and optionally pre-parsed content.

        For Srv6EndX, content is complex (includes sub-TLVs parsed to JSON),
        so we store both packed bytes and parsed content.

        Args:
            packed: the raw TLV value this object was decoded from.
            parsed_content: decoded fields of one End.X SID. When None (or an
                empty dict) the object starts with no content entries.
        """
        self._packed = packed
        # Stored as a list (rather than a single dict) so merge() can append
        # the entries of other Srv6EndX objects.
        self._content_list: list[dict[str, object]] = [parsed_content] if parsed_content else []

    @property
    def content(self) -> list[dict[str, object]]:
        """Return the parsed content list (one dict per End.X SID)."""
        return self._content_list

    def merge(self, other: Srv6EndX) -> None:
        """Merge another Srv6EndX's content into this one (appends in place)."""
        self._content_list.extend(other.content)

    def __repr__(self) -> str:
        """Return one summary line per End.X SID, joined by newlines."""
        return '\n'.join(
            [
                'behavior: {}, flags: {}, algorithm: {}, weight: {}, sid: {}'.format(
                    d.get('behavior'), d.get('flags'), d.get('algorithm'), d.get('weight'), d.get('sid')
                )
                for d in self.content
            ],
        )

    @classmethod
    def register_subsubtlv(cls) -> Callable[[type], type]:
        """Register a sub-sub-TLV class for SRv6 End.X.

        The returned decorator files the decorated class under its ``TLV``
        code in ``registered_subsubtlvs`` and returns the class unchanged.

        Raises:
            RuntimeError: if a class is already registered for that code.
        """

        def decorator(klass: type) -> type:
            code = klass.TLV  # type: ignore[attr-defined]
            if code in cls.registered_subsubtlvs:
                raise RuntimeError('only one class can be registered per SRv6 End.X Sub-TLV type')
            cls.registered_subsubtlvs[code] = klass
            return klass

        return decorator

    @classmethod
    def unpack_bgpls(cls, data: bytes) -> Srv6EndX:
        """Decode an SRv6 End.X SID TLV value (RFC 9514 Section 4.1).

        Args:
            data: the TLV value, i.e. the bytes following the Type/Length header.

        Returns:
            A Srv6EndX holding ``data`` and a dict with the keys 'flags',
            'behavior', 'algorithm', 'weight' and 'sid', plus one key per
            sub-TLV. Registered sub-TLVs contribute whatever their ``json()``
            emits. Unknown sub-TLVs contribute ``unknown-subtlv-<code>``
            mapped to their hex-encoded value.

        Raises:
            Notify: (3, 5) if ``data`` is shorter than the fixed part, or if a
                sub-TLV declares a length that runs past the end of ``data``.
        """
        if len(data) < SRV6_ENDX_MIN_LENGTH:
            raise Notify(3, 5, f'SRv6 End.X SID: data too short, need {SRV6_ENDX_MIN_LENGTH} bytes, got {len(data)}')
        original_data = data
        behavior = unpack('!H', data[0:2])[0]
        flags = cls.unpack_flags(data[2:3])
        algorithm = data[3]
        weight = data[4]
        # data[5] is the Reserved octet and is ignored.
        sid = IPv6.ntop(data[6:22])
        data = data[SRV6_ENDX_MIN_LENGTH:]

        header_size = cls.BGPLS_SUBTLV_HEADER_SIZE
        # Each entry is a '"key": value' JSON fragment; they are assembled
        # into a single object below.
        subtlvs: list[str] = []

        # Trailing bytes too short to hold a sub-TLV header are ignored.
        while len(data) >= header_size:
            code = unpack('!H', data[0:2])[0]
            length = unpack('!H', data[2:4])[0]
            end = header_size + length
            if len(data) < end:
                # Previously the value was silently truncated by slicing and a
                # short buffer handed to the sub-TLV decoder.
                raise Notify(
                    3,
                    5,
                    f'SRv6 End.X SID: sub-TLV {code} length {length} exceeds remaining {len(data) - header_size} bytes',
                )
            value = data[header_size:end]

            klass = cls.registered_subsubtlvs.get(code)
            if klass is not None:
                subtlvs.append(klass.unpack_bgpls(value).json())  # type: ignore[attr-defined]
            else:
                # Unknown sub-TLV: format as JSON string with hex data
                subtlvs.append(f'"unknown-subtlv-{code}": "{hexstring(value)}"')
            data = data[end:]

        parsed_content = {
            'flags': flags,
            'behavior': behavior,
            'algorithm': algorithm,
            'weight': weight,
            'sid': sid,
            **json.loads('{' + ', '.join(subtlvs) + '}'),
        }

        return cls(packed=original_data, parsed_content=parsed_content)

    def json(self, compact: bool = False) -> str:
        """Render all End.X SIDs as a ``"srv6-endx": [ ... ]`` JSON fragment.

        Every SID object is serialized on a single line. With ``compact=True``
        the whitespace after ',' and ':' inside each object is also dropped.
        """
        # Do not pass ``compact`` as ``indent``: json.dumps treats any non-None
        # indent (even False/0) as "pretty print" and inserts newlines, which
        # would break the single-line JSON message this fragment is embedded in.
        separators = (',', ':') if compact else None
        return '"srv6-endx": [ {} ]'.format(', '.join(json.dumps(d, separators=separators) for d in self.content))
