#!/usr/bin/env python3

# ========================================================
#    This code is for study only.
#    This code is not optimized.
#    This code is over simplified.
#    This code doesn't respect rules of good coding.
#    This code has duplicate code on purpose:
#        it keeps the flow easy to follow.
#    Some values that the code never uses are not converted.
# ========================================================

import sys
import io
import hashlib
import hmac


# Conversion to int
def to_int(length: bytes = b'') -> int:
    """Return *length* decoded as an unsigned big-endian integer.

    An empty byte string decodes to ``0``.
    """
    decoded = int.from_bytes(length, 'big', signed=False)
    return decoded

# Universal Labels of the KLV Encrypted Triplet (SMPTE & Interop variants).
ENCRYPTED_TRIPLET_KEYS = {
    "060e2b34020401010d010301027e0100",
    "060e2b34020401070d010301027e0100",
}


def _read_ber_length(stream):
    """Read one BER length field from the binary *stream*.

    Returns ``(raw, length)``: *raw* is the exact bytes consumed (kept
    because the MIC is computed over the encoded length fields) and
    *length* is the decoded value.  Both BER forms are handled:

    * short form -- a single byte below 0x80 that is the length itself;
    * long form  -- ``0x80 | n`` followed by *n* big-endian length bytes
      (``0x83`` + 3 bytes is the usual MXF case).

    At end of stream *raw* is ``b''`` and *length* is 0.
    """
    first = stream.read(1)
    if not first:
        return first, 0
    if first[0] < 0x80:
        return first, first[0]
    body = stream.read(first[0] & 0x7F)
    return first + body, to_int(body)


def _sha1_compress(block):
    """Return the SHA-1 chaining value after compressing one 64-byte *block*.

    Starts from the standard SHA-1 initial value and applies NO message
    padding, so this is the FIPS 186-2 ``G(t, c)`` function, not
    ``hashlib.sha1(block).digest()`` (hashlib does not expose the raw
    compression step).

    Raises ValueError if *block* is not exactly 64 bytes.
    """
    if len(block) != 64:
        raise ValueError("SHA-1 compression needs exactly 64 bytes")
    mask = 0xFFFFFFFF

    def rotl(x, n):
        return ((x << n) | (x >> (32 - n))) & mask

    # Message schedule: 16 big-endian words expanded to 80.
    w = [int.from_bytes(block[i:i + 4], 'big') for i in range(0, 64, 4)]
    for t in range(16, 80):
        w.append(rotl(w[t - 3] ^ w[t - 8] ^ w[t - 14] ^ w[t - 16], 1))

    h = (0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0)
    a, b, c, d, e = h
    for t in range(80):
        if t < 20:
            f, k = (b & c) | (~b & d), 0x5A827999
        elif t < 40:
            f, k = b ^ c ^ d, 0x6ED9EBA1
        elif t < 60:
            f, k = (b & c) | (b & d) | (c & d), 0x8F1BBCDC
        else:
            f, k = b ^ c ^ d, 0xCA62C1D6
        a, b, c, d, e = (rotl(a, 5) + f + e + k + w[t]) & mask, a, rotl(b, 30), c, d
    return b''.join(((x + y) & mask).to_bytes(4, 'big')
                    for x, y in zip(h, (a, b, c, d, e)))


def _derive_mic_key(content_key):
    """Derive the 16-byte HMAC-SHA1 MIC key from the AES *content_key*.

    FIPS 186-2 Appendix 3.1 general purpose RNG (no XSEED), two rounds:
    XKEY is the key left-aligned in a b-bit field (b = 160 for a 128-bit
    key), x0 = G(XKEY), XKEY = (1 + XKEY + x0) mod 2^b, x1 = G(XKEY);
    the MIC key is the first 16 bytes of x1.

    NOTE(review): this mirrors the SMPTE derivation used by asdcplib;
    Interop-era tools are understood to derive the MIC key differently --
    confirm against the files being checked.
    """
    seed_len = max(len(content_key), 20)
    modulus = 1 << (8 * seed_len)
    xkey = content_key.ljust(seed_len, b'\x00')
    x0 = _sha1_compress(xkey.ljust(64, b'\x00'))
    xkey_int = (1 + to_int(xkey) + to_int(x0)) % modulus
    x1 = _sha1_compress(xkey_int.to_bytes(seed_len, 'big').ljust(64, b'\x00'))
    return x1[:16]


if len(sys.argv) < 3:
    print("Usage: %s <mxf> <aeskey>" % sys.argv[0])
    sys.exit(1)

mxf_file = sys.argv[1]
aes_key = bytes.fromhex(sys.argv[2])
if len(aes_key) != 16:
    print("AES key must be 16 bytes (32 hex digits)")
    sys.exit(1)

# The MIC key depends only on the content key, so derive it once.
derivation_key = _derive_mic_key(aes_key)

with open(mxf_file, "rb") as file:

    while True:

        # Key : Universal Label
        key = file.read(16)

        # End of file (a trailing partial key cannot start a KLV)
        if len(key) < 16:
            break

        # Length (BER format, short or long form)
        length_raw, length = _read_ber_length(file)
        if not length_raw:
            break

        # Filter by KLV Encrypted Essence (SMPTE & Interop);
        # other packets are skipped without loading them into memory.
        if key.hex() not in ENCRYPTED_TRIPLET_KEYS:
            file.seek(length, io.SEEK_CUR)
            continue

        # Value
        value_bytes = file.read(length)

        # Show each KLV
        print("{key} - {length:>6d} - {data}...".format(
            key=key.hex(),
            length=length,
            data=value_bytes[0:16].hex()
        ))

        # read Value
        value = io.BytesIO(value_bytes)

        ccl_length, _ = _read_ber_length(value)
        print("CryptographicContextLink Length         : %s" % ccl_length.hex())
        print("CryptographicContextLink Value          : %s" % value.read(16).hex())
        pto_length, _ = _read_ber_length(value)
        print("PlaintextOffset Length                  : %s" % pto_length.hex())

        # plaintextOffset
        plaintextOffsetValue_bytes = value.read(8)
        plaintextOffsetValue = to_int(plaintextOffsetValue_bytes)
        print("PlaintextOffset Value                   : %s (%s bytes)" % (
            plaintextOffsetValue_bytes.hex(),
            plaintextOffsetValue
        ))

        sk_length, _ = _read_ber_length(value)
        print("SourceKey Length                        : %s" % sk_length.hex())
        print("SourceKey Value                         : %s" % value.read(16).hex())
        sl_length, _ = _read_ber_length(value)
        print("SourceLength Length                     : %s" % sl_length.hex())

        # sourceLengthValue
        sourceLengthValue_bytes = value.read(8)
        sourceLengthValue = to_int(sourceLengthValue_bytes)
        print("SourceLength Value                      : %s (%s bytes)" % (
            sourceLengthValue_bytes.hex(),
            sourceLengthValue
        ))

        # encryptedSourceLength
        encryptedSourceLength_bytes, encryptedSourceLength = _read_ber_length(value)
        print("Encrypted Source Length                 : %s (%s bytes)" % (
            encryptedSourceLength_bytes.hex(),
            encryptedSourceLength
        ))

        IV = value.read(16)
        print("Encrypted Source Value - IV             : %s" % IV.hex())

        checkValue = value.read(16)
        print("Encrypted Source Value - CheckValue     : %s" % checkValue.hex())

        plaintextData = value.read(plaintextOffsetValue)
        print("Encrypted Source Value - Plaintext Data : %s" % plaintextData.hex())

        # EncryptedData excludes plaintextData + IV + CheckValue.  Clamp at 0:
        # BytesIO.read() with a negative size reads to the end and would
        # swallow the integrity pack of a malformed packet.
        encryptedDataLength = max(0, encryptedSourceLength - plaintextOffsetValue - 16 - 16)
        encryptedData = value.read(encryptedDataLength)
        print("Encrypted Source Value - Encrypted Data : %s...%s" % (encryptedData[0:16].hex(), encryptedData[-16:].hex()))

        # TrackFile ID
        trackfile_length, _ = _read_ber_length(value)
        print("TrackFile ID Length                     : %s" % trackfile_length.hex())
        trackfile_value = value.read(16)
        print("TrackFile ID Value                      : %s" % trackfile_value.hex())

        # Sequence Number
        sequencenum_length, _ = _read_ber_length(value)
        print("Sequence Number Length                  : %s" % sequencenum_length.hex())
        sequencenum_value = value.read(8)
        print("Sequence Number Value                   : %s" % sequencenum_value.hex())

        # MIC
        mic_length, _ = _read_ber_length(value)
        print("Message Integrity Code (MIC) Length     : %s" % mic_length.hex())
        mic_value = value.read(20)
        print("Message Integrity Code (MIC) Value      : %s" % mic_value.hex())

        # HMAC-SHA1 keyed with the derived MIC key, over everything from the
        # IV up to (but excluding) the MIC value itself.
        digester = hmac.new(derivation_key, digestmod=hashlib.sha1)
        for part in (IV, checkValue, plaintextData, encryptedData,
                     trackfile_length, trackfile_value,
                     sequencenum_length, sequencenum_value,
                     mic_length):
            digester.update(part)
        print("Calculate MIC = %s" % digester.hexdigest())
