#!/usr/bin/env python3

# ========================================================
#    This code is for study only.
#    This code is not optimized.
#    This code is oversimplified.
#    This code doesn't follow good coding practices.
#    This code duplicates some logic on purpose,
#        to keep it easy to understand.
#    Some values the code does not use are left unconverted (raw hex).
# ========================================================

import sys
import io
from cryptography.hazmat.primitives.ciphers import ( Cipher, algorithms, modes )
from cryptography.hazmat.backends import default_backend

# Integer conversion helper
def to_int(length: bytes = b'') -> int:
    """Decode *length* as an unsigned, big-endian integer.

    An empty byte string decodes to 0.
    """
    decoded = int.from_bytes(length, "big", signed=False)
    return decoded

def _read_ber_length(stream):
    """Read one BER-encoded length from a binary stream.

    Short form: a single byte < 0x80 holds the length itself.
    Long form: a byte 0x80 | n, followed by n big-endian length bytes
    (e.g. 0x83 + 3 bytes, but 0x84, 0x88, ... are handled as well).

    Returns a ``(length, raw_bytes)`` tuple where ``raw_bytes`` are the exact
    encoded bytes consumed (kept for display).
    Raises EOFError if the stream ends inside the length field, and
    ValueError on the indefinite form (0x80), which is not supported here.
    """
    first = stream.read(1)
    if not first:
        raise EOFError("truncated BER length")
    if first[0] < 0x80:
        return first[0], first
    size = first[0] & 0x7F
    if size == 0:
        raise ValueError("indefinite BER length is not supported")
    rest = stream.read(size)
    if len(rest) != size:
        raise EOFError("truncated BER length")
    return to_int(rest), first + rest


def _read_item(stream):
    """Read one BER-length-prefixed item from *stream*.

    Returns ``(raw_length_bytes, value_bytes)``; the value is read using the
    declared length, so a zero-length item yields b''.
    """
    length, raw_length = _read_ber_length(stream)
    return raw_length, stream.read(length)


# Universal Labels of the KLV Encrypted Essence triplet (SMPTE & Interop).
ENCRYPTED_TRIPLET_KEYS = frozenset((
    bytes.fromhex("060e2b34020401010d010301027e0100"),
    bytes.fromhex("060e2b34020401070d010301027e0100"),
))

# Plaintext of the CheckValue block when decrypted with the right key
# (AS-DCP / SMPTE 429-6 check value).
ESV_CHECK_VALUE = b"CHUKCHUKCHUKCHUK"

if len(sys.argv) < 3:
    print("Usage: %s <mxf> <aeskey>" % sys.argv[0])
    sys.exit(1)

mxf_file = sys.argv[1]

# Validate the key once, up front, instead of failing inside the loop.
try:
    aes_key = bytes.fromhex(sys.argv[2])
    aes_algorithm = algorithms.AES(aes_key)
except ValueError as err:
    print("Invalid AES key: %s" % err)
    sys.exit(1)

with open(mxf_file, "rb") as file:

    while True:

        # Key : Universal Label
        key = file.read(16)

        # End of file (a truncated trailing key is treated the same way)
        if len(key) < 16:
            break

        # Length (BER format: short form or any long form)
        try:
            length, _ = _read_ber_length(file)
        except EOFError:
            print("Warning: truncated KLV length at end of file", file=sys.stderr)
            break

        # Filter by KLV Encrypted Essence (SMPTE & Interop);
        # other KLVs are skipped without loading their value into memory.
        if key not in ENCRYPTED_TRIPLET_KEYS:
            file.seek(length, io.SEEK_CUR)
            continue

        # Value
        value = file.read(length)

        # Show each KLV
        print("{key} - {length:>6d} - {data}...".format(
            key=key.hex(),
            length=length,
            data=value[0:16].hex()
        ))

        # read Value: a sequence of BER-length-prefixed items
        value = io.BytesIO(value)

        # CryptographicContextLink
        contextLinkLength_bytes, contextLinkValue_bytes = _read_item(value)
        print("CryptographicContextLink Length         : %s" % contextLinkLength_bytes.hex())
        print("CryptographicContextLink Value          : %s" % contextLinkValue_bytes.hex())

        # plaintextOffset: number of leading bytes of the source left unencrypted
        plaintextOffsetLength_bytes, plaintextOffsetValue_bytes = _read_item(value)
        plaintextOffsetValue = to_int(plaintextOffsetValue_bytes)
        print("PlaintextOffset Length                  : %s" % plaintextOffsetLength_bytes.hex())
        print("PlaintextOffset Value                   : %s (%s bytes)" % (
            plaintextOffsetValue_bytes.hex(),
            plaintextOffsetValue
        ))

        # SourceKey
        sourceKeyLength_bytes, sourceKeyValue_bytes = _read_item(value)
        print("SourceKey Length                        : %s" % sourceKeyLength_bytes.hex())
        print("SourceKey Value                         : %s" % sourceKeyValue_bytes.hex())

        # sourceLengthValue: size of the original (unpadded) essence
        sourceLengthLength_bytes, sourceLengthValue_bytes = _read_item(value)
        sourceLengthValue = to_int(sourceLengthValue_bytes)
        print("SourceLength Length                     : %s" % sourceLengthLength_bytes.hex())
        print("SourceLength Value                      : %s (%s bytes)" % (
            sourceLengthValue_bytes.hex(),
            sourceLengthValue
        ))

        # encryptedSourceLength (BER format)
        encryptedSourceLength, encryptedSourceLength_bytes = _read_ber_length(value)
        print("Encrypted Source Length                 : %s (%s bytes)" % (
            encryptedSourceLength_bytes.hex(),
            encryptedSourceLength
        ))

        IV = value.read(16)
        print("Encrypted Source Value - IV             : %s" % IV.hex())

        checkValue = value.read(16)
        print("Encrypted Source Value - CheckValue     : %s" % checkValue.hex())

        plaintextData = value.read(plaintextOffsetValue)
        print("Encrypted Source Value - Plaintext Data : %s" % plaintextData.hex())

        # EncryptedData excludes plaintextData + IV + CheckValue.
        # Clamped at 0: BytesIO.read() with a negative size would read
        # everything that remains and desynchronise the following items.
        encryptedDataLength = max(0, encryptedSourceLength - plaintextOffsetValue - 16 - 16)
        encryptedData = value.read(encryptedDataLength)
        print("Encrypted Source Value - Encrypted Data : %s..." % encryptedData[0:16].hex())

        # TrackFile ID
        trackFileLength_bytes, trackFileValue_bytes = _read_item(value)
        print("TrackFile ID Length                     : %s" % trackFileLength_bytes.hex())
        print("TrackFile ID Value                      : %s" % trackFileValue_bytes.hex())

        # Sequence Number
        sequenceLength_bytes, sequenceValue_bytes = _read_item(value)
        print("Sequence Number Length                  : %s" % sequenceLength_bytes.hex())
        print("Sequence Number Value                   : %s" % sequenceValue_bytes.hex())

        # MIC
        micLength_bytes, micValue_bytes = _read_item(value)
        print("Message Integrity Code (MIC) Length     : %s" % micLength_bytes.hex())
        print("Message Integrity Code (MIC) Value      : %s" % micValue_bytes.hex())


        # ---------------------------- #
        #          Decryption          #
        # ---------------------------- #

        # Set cryptographic engine (AES-CBC with the IV read from the triplet)
        decryptor = Cipher(
            aes_algorithm,
            modes.CBC(IV),
            backend=default_backend()
        ).decryptor()

        # CheckValue is the first CBC block after the IV: decrypting it both
        # advances the CBC chain and lets us verify the key.
        decryptedCheckValue = decryptor.update(checkValue)
        print("Key Check (decrypted CheckValue)        : %s" % (
            "OK" if decryptedCheckValue == ESV_CHECK_VALUE
            else "MISMATCH (wrong AES key?)"
        ))

        # PlaintextData goes straight to the output; the encrypted part is
        # decrypted in one call (per-block `bytes +=` would be quadratic).
        plaintext = plaintextData + decryptor.update(encryptedData)

        print("Plaintext Source Value                  : %d bytes" % len(plaintext))
        print("Padding                                 : %d bytes" % (len(plaintext) - sourceLengthValue))

        # write Plaintext (without CBC padding) to a file named after the
        # file offset just past this triplet
        with open("output_%d.j2c" % file.tell(), "wb") as output:
            output.write(plaintext[0:sourceLengthValue])
