[aes.py] Optimize

This commit is contained in:
rzhxeo 2015-06-03 16:42:44 +02:00
parent b4b1d4be3e
commit 2d129b7d4b

View File

@ -35,11 +35,9 @@ def aes_ctr_decrypt(data, key, counter):
for i in range(block_count):
counter_block = counter.next_value()
block = data[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]
block += [0] * (BLOCK_SIZE_BYTES - len(block))
cipher_counter_block = aes_encrypt(counter_block, expanded_key)
decrypted_data += xor(block, cipher_counter_block)
decrypted_data = decrypted_data[:len(data)]
return decrypted_data
@ -118,15 +116,31 @@ def aes_encrypt(data, expanded_key):
@param {int[]} expanded_key 176/208/240-Byte expanded key
@returns {int[]} 16-Byte cipher
"""
precompute_rijndael_mul()
rounds = len(expanded_key) // BLOCK_SIZE_BYTES - 1
data = xor(data, expanded_key[:BLOCK_SIZE_BYTES])
for i in range(1, rounds + 1):
data = sub_bytes(data)
data = shift_rows(data)
if i != rounds:
data = mix_columns(data)
data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES])
# xor
data = [x ^ y for x, y in zip(data, expanded_key[:BLOCK_SIZE_BYTES])]
for _round in range(1, rounds + 1):
# sub bytes
data = [SBOX[x] for x in data]
# shift rows
data_shifted = [0] * 16
for column in range(4):
for row in range(4):
data_shifted[column*4 + row] = data[((column + row) & 0b11) * 4 + row]
data = data_shifted
if _round != rounds:
# mix columns
for j in range(0,16,4):
column_data = data[j:j + 4]
for row in range(4):
mixed = 0
for column in range(4):
# xor is (+) and (-)
mixed ^= rijndael_mul_precomputed[column_data[column]][MIX_COLUMN_MATRIX[row][column]]
data[j+row] = mixed
# xor
data = [x ^ y for x, y in zip(data, expanded_key[_round * BLOCK_SIZE_BYTES: (_round + 1) * BLOCK_SIZE_BYTES])]
return data
@ -139,15 +153,31 @@ def aes_decrypt(data, expanded_key):
@param {int[]} expanded_key 176/208/240-Byte expanded key
@returns {int[]} 16-Byte state
"""
precompute_rijndael_mul()
rounds = len(expanded_key) // BLOCK_SIZE_BYTES - 1
for i in range(rounds, 0, -1):
data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES])
if i != rounds:
data = mix_columns_inv(data)
data = shift_rows_inv(data)
data = sub_bytes_inv(data)
data = xor(data, expanded_key[:BLOCK_SIZE_BYTES])
for _round in range(rounds, 0, -1):
# xor
data = [x ^ y for x, y in zip(data, expanded_key[_round * BLOCK_SIZE_BYTES: (_round + 1) * BLOCK_SIZE_BYTES])]
if _round != rounds:
# mix columns
for j in range(0,16,4):
column_data = data[j:j + 4]
for row in range(4):
mixed = 0
for column in range(4):
# xor is (+) and (-)
mixed ^= rijndael_mul_precomputed[column_data[column]][MIX_COLUMN_MATRIX_INV[row][column]]
data[j+row] = mixed
# shift rows inv
data_shifted = [0] * 16
for column in range(4):
for row in range(4):
data_shifted[column*4 + row] = data[((column - row) & 0b11) * 4 + row]
data = data_shifted
# sub bytes
data = [SBOX_INV[x] for x in data]
# xor
data = [x ^ y for x, y in zip(data, expanded_key[:BLOCK_SIZE_BYTES])]
return data
@ -262,6 +292,23 @@ RIJNDAEL_LOG_TABLE = (0x00, 0x00, 0x19, 0x01, 0x32, 0x02, 0x1a, 0xc6, 0x4b, 0xc7
0x53, 0x39, 0x84, 0x3c, 0x41, 0xa2, 0x6d, 0x47, 0x14, 0x2a, 0x9e, 0x5d, 0x56, 0xf2, 0xd3, 0xab,
0x44, 0x11, 0x92, 0xd9, 0x23, 0x20, 0x2e, 0x89, 0xb4, 0x7c, 0xb8, 0x26, 0x77, 0x99, 0xe3, 0xa5,
0x67, 0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07)
rijndael_mul_precomputed = None
def precompute_rijndael_mul():
global rijndael_mul_precomputed
if rijndael_mul_precomputed is not None:
return
rijndael_mul_precomputed = [[0] * 256 for _ in range(256)]
for i in range(256):
for j in range(256):
rijndael_mul_precomputed[i][j] = rijndael_mul(i, j)
def rijndael_mul(a, b):
if(a == 0 or b == 0):
return 0
return RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[a] + RIJNDAEL_LOG_TABLE[b]) % 0xFF]
def sub_bytes(data):
@ -288,51 +335,6 @@ def xor(data1, data2):
return [x ^ y for x, y in zip(data1, data2)]
def rijndael_mul(a, b):
if(a == 0 or b == 0):
return 0
return RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[a] + RIJNDAEL_LOG_TABLE[b]) % 0xFF]
def mix_column(data, matrix):
data_mixed = []
for row in range(4):
mixed = 0
for column in range(4):
# xor is (+) and (-)
mixed ^= rijndael_mul(data[column], matrix[row][column])
data_mixed.append(mixed)
return data_mixed
def mix_columns(data, matrix=MIX_COLUMN_MATRIX):
data_mixed = []
for i in range(4):
column = data[i * 4: (i + 1) * 4]
data_mixed += mix_column(column, matrix)
return data_mixed
def mix_columns_inv(data):
return mix_columns(data, MIX_COLUMN_MATRIX_INV)
def shift_rows(data):
data_shifted = []
for column in range(4):
for row in range(4):
data_shifted.append(data[((column + row) & 0b11) * 4 + row])
return data_shifted
def shift_rows_inv(data):
data_shifted = []
for column in range(4):
for row in range(4):
data_shifted.append(data[((column - row) & 0b11) * 4 + row])
return data_shifted
def inc(data):
data = data[:] # copy
for i in range(len(data) - 1, -1, -1):