#!/usr/bin/env python3

import base64
import sys

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from lxml import etree

# Input locations and key passphrase; edit these before running the script.
PRIVATE_KEY_FILENAME = "private_key.pem"  # PEM private key used to RSA-decrypt the KDM key blocks
PRIVATE_KEY_PASSWORD = None  # passphrase (bytes) for an encrypted PEM; None for an unencrypted key
KDM_FILENAME = "kdm.xml"  # KDM XML document to inspect

# Load the RSA private key: read the raw PEM bytes, then deserialize them,
# unlocking with PRIVATE_KEY_PASSWORD when the PEM is encrypted.
with open(PRIVATE_KEY_FILENAME, "rb") as pem_file:
    private_key_content = pem_file.read()
    private_key = serialization.load_pem_private_key(
        private_key_content,
        PRIVATE_KEY_PASSWORD,
    )

# RSA-OAEP padding applied to every encrypted key block: SHA-1 serves as both
# the OAEP digest and the MGF1 mask-generation hash, with no OAEP label.
oaep_padding = padding.OAEP(
    algorithm=hashes.SHA1(),
    mgf=padding.MGF1(algorithm=hashes.SHA1()),
    label=None,
)

# Read and parse the KDM document. The file comes from outside this script, so
# treat it as untrusted: entity resolution and network access are disabled to
# prevent XML external entity (XXE) tricks from reading local files or URLs.
# `tree` is the root element of the parsed document.
kdm_parser = etree.XMLParser(resolve_entities=False, no_network=True)
with open(KDM_FILENAME, "rb") as kdm_file:
    tree = etree.fromstring(kdm_file.read(), parser=kdm_parser)

# Public (unencrypted) half of the KDM: every TypedKeyId element, matched by
# local name so the namespace prefix does not matter. Expected location:
# AuthenticatedPublic > RequiredExtensions > KDMRequiredExtensions > KeyIdList > TypedKeyId[…]
keys = tree.xpath("//*[local-name()='TypedKeyId']")

for typed_key_id in keys:
    # Take the first text node of each child; a missing KeyType or KeyId
    # raises IndexError (KeyType is looked up first).
    key_type_text = typed_key_id.xpath("./*[local-name()='KeyType']/text()")[0]
    key_id_text = typed_key_id.xpath("./*[local-name()='KeyId']/text()")[0]
    print("KeyId %s - KeyType %s" % (key_id_text, key_type_text))

# Private half of the KDM: each RSA-encrypted key block is the CipherValue of
# an EncryptedKey's CipherData:
#   AuthenticatedPrivate > enc:EncryptedKey[…] > enc:CipherData > enc:CipherValue
# Matching the full path (instead of any CipherValue in the document) avoids
# trying to RSA-decrypt unrelated elements.
cipher_values = tree.xpath(
    "//*[local-name()='EncryptedKey']"
    "/*[local-name()='CipherData']"
    "/*[local-name()='CipherValue']"
)


def _decode_utf8(raw):
    """Decode a text field (key type or date) of the decrypted key block."""
    return raw.decode("utf-8")


# Fixed layout of a decrypted key block, in byte order:
# (output line format, field length in bytes, field formatter).
# Field offsets are the running sum of the lengths: 0, 16, 36, 52, 56, 72,
# 97, 122; the complete block is 138 bytes.
KEY_BLOCK_FIELDS = (
    ("* Structure ID           : %s", 16, bytes.hex),
    ("* Certificate ThumbPrint : %s", 20, bytes.hex),
    ("* CPL Id                 : %s", 16, bytes.hex),
    ("* Key Type               : %s", 4, _decode_utf8),
    ("* Key Id                 : %s", 16, bytes.hex),
    ("* Date Not Valid Before  : %s", 25, _decode_utf8),
    ("* Date Not Valid After   : %s", 25, _decode_utf8),
    ("* AES Key                : %s", 16, bytes.hex),
)
KEY_BLOCK_LENGTH = sum(length for _, length, _ in KEY_BLOCK_FIELDS)

decrypt_failures = 0

for cipher_value in cipher_values:
    encoded = cipher_value.text or ""

    # Base64-decode, then RSA-OAEP decrypt. Both steps raise ValueError on
    # bad input (binascii.Error is a ValueError subclass); cryptography raises
    # it when the block was not encrypted for this private key. Report the
    # failure and keep dumping the remaining blocks.
    try:
        encrypted_value = base64.b64decode(encoded)
        plaintext = private_key.decrypt(encrypted_value, oaep_padding)
    except ValueError as err:
        decrypt_failures += 1
        print(
            "! Cannot decrypt CipherValue starting %r: %s" % (encoded[:16], err),
            file=sys.stderr,
        )
        continue

    print("* Cipher Base64          : %s" % cipher_value.text)
    print("* Cipher Text            : %s" % encrypted_value.hex())
    print("* Plaintext              : %s" % plaintext.hex())

    # Slicing a block of the wrong size would silently print truncated or
    # misaligned fields, so only parse blocks of the expected length.
    if len(plaintext) != KEY_BLOCK_LENGTH:
        print(
            "! Unexpected plaintext length %d (expected %d); fields not parsed"
            % (len(plaintext), KEY_BLOCK_LENGTH),
            file=sys.stderr,
        )
        print("")
        continue

    offset = 0
    for line_format, length, render in KEY_BLOCK_FIELDS:
        print(line_format % render(plaintext[offset:offset + length]))
        offset += length
    print("")

# Preserve a failing exit status when any block could not be decrypted.
if decrypt_failures:
    sys.exit("%d CipherValue(s) could not be decrypted" % decrypt_failures)
