#!/usr/bin/env python3

# ========================================================
#    This code is for study only.
#    This code is not optimized.
#    This code is over simplified.
#    This code doesn't follow good coding practices.
#    This code contains duplicated code,
#        to keep it easy to understand.
#    Some values that the code does not use are not converted.
# ========================================================

import sys
import io
import wave

# Conversion to int
def to_int(length: bytes = b'') -> int:
    """Return *length* interpreted as an unsigned big-endian integer.

    MXF stores its numeric fields big-endian; an empty byte string gives 0.
    """
    number = int.from_bytes(length, 'big')
    return number

# Universal Labels (KLV keys, as hex strings) of the only two packets used here.
WAVE_AUDIO_DESCRIPTOR_KEY = "060e2b34025301010d01010101014800"
SOUND_ESSENCE_KEY = "060e2b34010201010d01030116010101"


def read_ber_length(stream) -> int:
    """Read a BER-encoded KLV length from *stream* and return it.

    Short form: a first byte below 0x80 is the length itself.
    Long form: the low 7 bits of the first byte give the number of
    big-endian length bytes that follow (e.g. 0x83 + 3 bytes,
    0x88 + 8 bytes).

    Raises:
        EOFError: if the stream ends inside the length field.
    """
    first = stream.read(1)
    if not first:
        raise EOFError("truncated KLV length")
    if first[0] < 0x80:
        return first[0]
    size = first[0] & 0x7F
    data = stream.read(size)
    if len(data) != size:
        raise EOFError("truncated KLV length")
    return int.from_bytes(data, byteorder='big')


def parse_wave_audio_descriptor(value: bytes):
    """Parse the local set of a Wave Audio Descriptor.

    Each item is a 2-byte local tag, a 2-byte big-endian length and the
    item value. Only the items needed for the WAV header are decoded.

    Returns:
        (sampling_rate, channel_count, quantization_bits); an entry is None
        when its local tag is absent from the descriptor.
    """
    sampling_rate = channel_count = quantization_bits = None
    buffer = io.BytesIO(value)

    while True:
        localtag = buffer.read(2)
        if len(localtag) < 2:
            break
        item_length = int.from_bytes(buffer.read(2), byteorder='big')
        item_value = buffer.read(item_length)
        tag = localtag.hex()

        # Audio Sampling Rate (Hz), stored as a Rational: numerator, denominator
        if tag == "3d03":
            numerator = int.from_bytes(item_value[0:4], byteorder='big')
            # A missing/zero denominator is treated as 1 to avoid dividing by 0
            denominator = int.from_bytes(item_value[4:8], byteorder='big') or 1
            sampling_rate = numerator // denominator
            print("header = Sampling Rate =", sampling_rate)

        # Channel Count
        elif tag == "3d07":
            channel_count = int.from_bytes(item_value, byteorder='big')
            print("header = Channel Count =", channel_count)

        # Quantization bits
        elif tag == "3d01":
            quantization_bits = int.from_bytes(item_value, byteorder='big')
            print("header = Quantization bits =", quantization_bits)

    return sampling_rate, channel_count, quantization_bits


def _open_wav(output_file, sampling_rate, channel_count, quantization_bits):
    """Create a wave writer for *output_file* configured from descriptor values."""
    if None in (sampling_rate, channel_count, quantization_bits):
        raise ValueError("Wave Audio Descriptor lacks sampling rate, "
                         "channel count or quantization bits")
    wav = wave.open(output_file, "wb")
    wav.setnchannels(channel_count)                  # channels
    wav.setframerate(sampling_rate)                  # sampling rate (Hz)
    wav.setsampwidth((quantization_bits + 7) // 8)   # bit depth, rounded up to whole bytes
    return wav


def extract_audio(mxf_file, output_file="output.wav"):
    """Extract the Sound Essence of *mxf_file* into the WAV file *output_file*.

    The WAV parameters come from the first Wave Audio Descriptor; every
    Sound Essence KLV value is appended as raw frame data. Frames are
    written through the ``wave`` writer so the RIFF/data sizes are
    patched when it is closed.
    """
    wav = None      # wave writer, created from the first descriptor
    pending = []    # essence values seen before any descriptor

    try:
        with open(mxf_file, "rb") as file:
            while True:

                # Key : Universal Label (16 bytes)
                key = file.read(16)
                if len(key) < 16:   # end of file (or truncated trailing key)
                    break

                # Length (BER format, short or long form)
                try:
                    length = read_ber_length(file)
                except EOFError:
                    break

                # Skip every other KLV without loading its value into memory
                key_hex = key.hex()
                if key_hex not in (WAVE_AUDIO_DESCRIPTOR_KEY, SOUND_ESSENCE_KEY):
                    file.seek(length, io.SEEK_CUR)
                    continue

                # Value
                value = file.read(length)

                # Show filtered KLV
                print("{key} - {length:>6d} bytes - {value}...".format(
                    key    = key_hex,
                    length = length,
                    value  = value[0:16].hex()
                ))

                if key_hex == WAVE_AUDIO_DESCRIPTOR_KEY:
                    print("read headers")
                    params = parse_wave_audio_descriptor(value)
                    # MXF files may repeat header metadata (e.g. in the footer
                    # partition): only the first descriptor creates the WAV,
                    # re-creating it would discard the audio already written.
                    if wav is None:
                        wav = _open_wav(output_file, *params)
                        for chunk in pending:
                            wav.writeframesraw(chunk)
                        pending.clear()
                else:
                    print("read data audio")
                    if wav is None:
                        pending.append(value)   # descriptor not seen yet
                    else:
                        wav.writeframesraw(value)
    finally:
        if wav is not None:
            wav.close()   # patches the RIFF and data chunk sizes

    if pending:
        print("no Wave Audio Descriptor found: audio not written",
              file=sys.stderr)


def main(argv=None):
    """Command-line entry point: ``<script> <mxf>`` writes ``output.wav``."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        print("Usage: %s <mxf>" % argv[0])
        sys.exit(1)
    extract_audio(argv[1])


if __name__ == "__main__":
    main()