Source code for track.utils.encrypted

import socket

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat

from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PublicKey
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding


[docs]class EncryptedSocket(socket.socket): """Socket with an encrypted layer""" def __init__(self, *args, **kwargs): raise TypeError(f"{self.__class__.__name__} does not have a public constructor.") @classmethod def _create(cls, sock, server_side=False, handshaked=False): kwargs = dict( family=sock.family, type=sock.type, proto=sock.proto, fileno=sock.fileno() ) self = cls.__new__(cls, **kwargs) super(EncryptedSocket, self).__init__(**kwargs) self.settimeout(sock.gettimeout()) sock.detach() self.cipher = None self.message_size = None self.message_received = None self.server_side = server_side # do the handshake if client if not self.server_side and not handshaked: self._handshake() return self def _handshake(self): """Open a socket to address, port and initialize the encryption layer by exchanging a key using X25519. The key is used as an AES key throughout the communication. Returns ------- return itself """ private_key = X25519PrivateKey.generate() pubkey = private_key.public_key().public_bytes( encoding=Encoding.Raw, format=PublicFormat.Raw ) # send client public key super().send(pubkey) # receive server public Key data = super().recv(32) # from public key get shared_key server_key = X25519PublicKey.from_public_bytes(data) shared_key = private_key.exchange(server_key) key = HKDF( algorithm=hashes.SHA256(), length=48, salt=None, info=b'handshake data', backend=default_backend() ).derive(shared_key) self.cipher = Cipher( algorithms.AES(key[0:32]), modes.CBC(key[32:]), backend=default_backend() ) return self
[docs] def accept(self): """Accept an incoming connection & initialize the encryption layer for that client Returns ------- returns (socket, addr) of the client """ clt, addr = super().accept() # Generate a private key server_key = X25519PrivateKey.generate() pubkey = server_key.public_key().public_bytes( encoding=Encoding.Raw, format=PublicFormat.Raw ) data = clt.recv(32) # Receive client public Key clt.sendall(pubkey) # send public key to client public_key = X25519PublicKey.from_public_bytes(data) # get Shared key shared_key = server_key.exchange(public_key) shared_key = HKDF( algorithm=hashes.SHA256(), length=48, salt=None, info=b'handshake data', backend=default_backend() ).derive(shared_key) encrypted_socket = wrap_socket(clt, False, handshaked=True) encrypted_socket.cipher = Cipher( algorithms.AES(shared_key[0:32]), modes.CBC(shared_key[32:]), backend=default_backend() ) return encrypted_socket, addr
[docs] def send(self, data: bytes, flags: int = 0) -> int: self.sendall(data, flags) return len(data)
[docs] def sendall(self, data, flags: int = 0): if isinstance(data, bytearray): data = bytes(data) encrypt = self.cipher.encryptor() padder = padding.PKCS7(128).padder() padded_bytes = padder.update(data) padded_bytes += padder.finalize() encrypted = encrypt.update(padded_bytes) encrypted += encrypt.finalize() super().sendall(encrypted, flags) return len(data)
[docs] def readsize(self): decrypt = self.cipher.decryptor() unpadder = padding.PKCS7(128).unpadder() size = super().recv(4) return size, (decrypt, unpadder)
[docs] def recv(self, buffersize, flags: int = 0, context=None): data = super().recv(buffersize, flags) # no data nothing to decrypt if not data: return data # ---- if context is None: decrypt = self.cipher.decryptor() unpadder = padding.PKCS7(128).unpadder() else: decrypt, unpadder = context # ---- decrypted = decrypt.update(data) while True: try: decrypted += decrypt.finalize() break except ValueError: data = super().recv(buffersize, flags) decrypted += decrypt.update(data) unpadded = unpadder.update(decrypted) unpadded += unpadder.finalize() return unpadded
[docs]def wrap_socket(sock, server_side=False, handshaked=False): return EncryptedSocket._create( sock=sock, server_side=server_side, handshaked=handshaked )