diff --git a/youtube_dl/aes.py b/youtube_dl/aes.py index 459a76457..e323b9a32 100644 --- a/youtube_dl/aes.py +++ b/youtube_dl/aes.py @@ -35,11 +35,9 @@ def aes_ctr_decrypt(data, key, counter): for i in range(block_count): counter_block = counter.next_value() block = data[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES] - block += [0] * (BLOCK_SIZE_BYTES - len(block)) cipher_counter_block = aes_encrypt(counter_block, expanded_key) decrypted_data += xor(block, cipher_counter_block) - decrypted_data = decrypted_data[:len(data)] return decrypted_data @@ -118,15 +116,31 @@ def aes_encrypt(data, expanded_key): @param {int[]} expanded_key 176/208/240-Byte expanded key @returns {int[]} 16-Byte cipher """ + precompute_rijndael_mul() rounds = len(expanded_key) // BLOCK_SIZE_BYTES - 1 - - data = xor(data, expanded_key[:BLOCK_SIZE_BYTES]) - for i in range(1, rounds + 1): - data = sub_bytes(data) - data = shift_rows(data) - if i != rounds: - data = mix_columns(data) - data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]) + # xor + data = [x ^ y for x, y in zip(data, expanded_key[:BLOCK_SIZE_BYTES])] + for _round in range(1, rounds + 1): + # sub bytes + data = [SBOX[x] for x in data] + # shift rows + data_shifted = [0] * 16 + for column in range(4): + for row in range(4): + data_shifted[column*4 + row] = data[((column + row) & 0b11) * 4 + row] + data = data_shifted + if _round != rounds: + # mix columns + for j in range(0,16,4): + column_data = data[j:j + 4] + for row in range(4): + mixed = 0 + for column in range(4): + # xor is (+) and (-) + mixed ^= rijndael_mul_precomputed[column_data[column]][MIX_COLUMN_MATRIX[row][column]] + data[j+row] = mixed + # xor + data = [x ^ y for x, y in zip(data, expanded_key[_round * BLOCK_SIZE_BYTES: (_round + 1) * BLOCK_SIZE_BYTES])] return data @@ -139,15 +153,31 @@ def aes_decrypt(data, expanded_key): @param {int[]} expanded_key 176/208/240-Byte expanded key @returns {int[]} 16-Byte state """ + precompute_rijndael_mul() rounds = len(expanded_key) // BLOCK_SIZE_BYTES - 1 - - for i in range(rounds, 0, -1): - data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]) - if i != rounds: - data = mix_columns_inv(data) - data = shift_rows_inv(data) - data = sub_bytes_inv(data) - data = xor(data, expanded_key[:BLOCK_SIZE_BYTES]) + for _round in range(rounds, 0, -1): + # xor + data = [x ^ y for x, y in zip(data, expanded_key[_round * BLOCK_SIZE_BYTES: (_round + 1) * BLOCK_SIZE_BYTES])] + if _round != rounds: + # mix columns + for j in range(0,16,4): + column_data = data[j:j + 4] + for row in range(4): + mixed = 0 + for column in range(4): + # xor is (+) and (-) + mixed ^= rijndael_mul_precomputed[column_data[column]][MIX_COLUMN_MATRIX_INV[row][column]] + data[j+row] = mixed + # shift rows inv + data_shifted = [0] * 16 + for column in range(4): + for row in range(4): + data_shifted[column*4 + row] = data[((column - row) & 0b11) * 4 + row] + data = data_shifted + # sub bytes + data = [SBOX_INV[x] for x in data] + # xor + data = [x ^ y for x, y in zip(data, expanded_key[:BLOCK_SIZE_BYTES])] return data @@ -262,6 +292,23 @@ RIJNDAEL_LOG_TABLE = (0x00, 0x00, 0x19, 0x01, 0x32, 0x02, 0x1a, 0xc6, 0x4b, 0xc7 0x53, 0x39, 0x84, 0x3c, 0x41, 0xa2, 0x6d, 0x47, 0x14, 0x2a, 0x9e, 0x5d, 0x56, 0xf2, 0xd3, 0xab, 0x44, 0x11, 0x92, 0xd9, 0x23, 0x20, 0x2e, 0x89, 0xb4, 0x7c, 0xb8, 0x26, 0x77, 0x99, 0xe3, 0xa5, 0x67, 0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07) +rijndael_mul_precomputed = None + + +def precompute_rijndael_mul(): + global rijndael_mul_precomputed + if rijndael_mul_precomputed is not None: + return + rijndael_mul_precomputed = [[0] * 256 for _ in range(256)] + for i in range(256): + for j in range(256): + rijndael_mul_precomputed[i][j] = rijndael_mul(i, j) + + +def rijndael_mul(a, b): + if(a == 0 or b == 0): + return 0 + return RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[a] + RIJNDAEL_LOG_TABLE[b]) % 0xFF] def sub_bytes(data): @@ -288,51 +335,6 @@ def xor(data1, data2): return [x ^ y for x, y in zip(data1, data2)] -def rijndael_mul(a, b): - if(a == 0 or b == 0): - return 0 - return RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[a] + RIJNDAEL_LOG_TABLE[b]) % 0xFF] - - -def mix_column(data, matrix): - data_mixed = [] - for row in range(4): - mixed = 0 - for column in range(4): - # xor is (+) and (-) - mixed ^= rijndael_mul(data[column], matrix[row][column]) - data_mixed.append(mixed) - return data_mixed - - -def mix_columns(data, matrix=MIX_COLUMN_MATRIX): - data_mixed = [] - for i in range(4): - column = data[i * 4: (i + 1) * 4] - data_mixed += mix_column(column, matrix) - return data_mixed - - -def mix_columns_inv(data): - return mix_columns(data, MIX_COLUMN_MATRIX_INV) - - -def shift_rows(data): - data_shifted = [] - for column in range(4): - for row in range(4): - data_shifted.append(data[((column + row) & 0b11) * 4 + row]) - return data_shifted - - -def shift_rows_inv(data): - data_shifted = [] - for column in range(4): - for row in range(4): - data_shifted.append(data[((column - row) & 0b11) * 4 + row]) - return data_shifted - - def inc(data): data = data[:] # copy for i in range(len(data) - 1, -1, -1):