Better AES

June 25, 2025

السلام عليكم ورحمة الله وبركاته

Today's challenge tests our understanding of the AES implementation and shows how one simple change can destroy AES's security.

The challenge code, shown below, implements AES:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from secrets import token_bytes

class BetterAES:
    BLOCK_SIZE = 16          # Block size in bytes
    KEY_SIZE = 32            # Key size in bytes (256 bits)
    NUM_ROUNDS = 14          # Number of rounds for AES-256

    # Round constants
    RCON = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]

    # for easier testing
    sbox= list(range(0, 256))

    def __init__(self, key: bytes):
        self.aes_256_key = key

        # inverse S-box
        self.inv_sbox = [0] * 256
        for i, v in enumerate(self.sbox):
            self.inv_sbox[v] = i

    def gf_mult(self, a, b):
        result = 0
        for _ in range(8):
            if b & 1:
                result ^= a
            high = a & 0x80
            a = (a << 1) & 0xFF
            if high:
                a ^= 0x1B
            b >>= 1
        return result

    def sub_bytes(self, state):
        return [self.sbox[b] for b in state]

    def inv_sub_bytes(self, state):
        return [self.inv_sbox[b] for b in state]

    def shift_rows(self, state):
        new = [0] * 16
        for r in range(4):
            for c in range(4):
                new[r + 4*c] = state[r + 4*((c + r) % 4)]
        return new

    def inv_shift_rows(self, state):
        new = [0] * 16
        for r in range(4):
            for c in range(4):
                new[r + 4*c] = state[r + 4*((c - r) % 4)]
        return new

    def mix_columns(self, state):
        new = state.copy()
        for c in range(4):
            col = state[4*c:4*c+4]
            new[4*c + 0] = self.gf_mult(col[0], 2) ^ self.gf_mult(col[1], 3) ^ col[2] ^ col[3]
            new[4*c + 1] = col[0] ^ self.gf_mult(col[1], 2) ^ self.gf_mult(col[2], 3) ^ col[3]
            new[4*c + 2] = col[0] ^ col[1] ^ self.gf_mult(col[2], 2) ^ self.gf_mult(col[3], 3)
            new[4*c + 3] = self.gf_mult(col[0], 3) ^ col[1] ^ col[2] ^ self.gf_mult(col[3], 2)
        return new

    def inv_mix_columns(self, state):
        new = state.copy()
        for c in range(4):
            col = state[4*c:4*c+4]
            new[4*c + 0] = self.gf_mult(col[0], 0x0e) ^ self.gf_mult(col[1], 0x0b) ^ self.gf_mult(col[2], 0x0d) ^ self.gf_mult(col[3], 0x09)
            new[4*c + 1] = self.gf_mult(col[0], 0x09) ^ self.gf_mult(col[1], 0x0e) ^ self.gf_mult(col[2], 0x0b) ^ self.gf_mult(col[3], 0x0d)
            new[4*c + 2] = self.gf_mult(col[0], 0x0d) ^ self.gf_mult(col[1], 0x09) ^ self.gf_mult(col[2], 0x0e) ^ self.gf_mult(col[3], 0x0b)
            new[4*c + 3] = self.gf_mult(col[0], 0x0b) ^ self.gf_mult(col[1], 0x0d) ^ self.gf_mult(col[2], 0x09) ^ self.gf_mult(col[3], 0x0e)
        return new

    def add_round_key(self, state, key):
        return [b ^ k for b,k in zip(state,key)]

    def key_expansion(self, key):
        if len(key) != self.KEY_SIZE:
            raise ValueError("Key must be 32 bytes")
        expanded = list(key)
        i = self.KEY_SIZE
        rcon_i = 0
        while len(expanded) < self.BLOCK_SIZE * (self.NUM_ROUNDS + 1):
            temp = expanded[-4:]
            if i % self.KEY_SIZE == 0:
                # RotWord + SubWord + Rcon
                temp = temp[1:] + temp[:1]
                temp = [self.sbox[b] for b in temp]
                temp[0] ^= self.RCON[rcon_i]
                rcon_i += 1
            elif i % self.KEY_SIZE == 16:
                # SubWord only
                temp = [self.sbox[b] for b in temp]
            for j in range(4):
                expanded.append(expanded[i - self.KEY_SIZE + j] ^ temp[j])
            i += 4
        # split into round keys
        return [expanded[16*r:16*(r+1)] for r in range(self.NUM_ROUNDS+1)]

    def encrypt_block(self, plaintext):
        key = self.aes_256_key
        if len(plaintext) != self.BLOCK_SIZE or len(key) != self.KEY_SIZE:
            raise ValueError("Plaintext must be 16 bytes and key 32 bytes")
        state = list(plaintext)
        round_keys = self.key_expansion(key)
        state = self.add_round_key(state, round_keys[0])
        for rnd in range(1, self.NUM_ROUNDS):
            state = self.sub_bytes(state)
            state = self.shift_rows(state)
            state = self.mix_columns(state)
            state = self.add_round_key(state, round_keys[rnd])
        state = self.sub_bytes(state)
        state = self.shift_rows(state)
        state = self.add_round_key(state, round_keys[self.NUM_ROUNDS])
        return bytes(state)

    def decrypt_block(self, ciphertext):
        key = self.aes_256_key
        if len(ciphertext) != self.BLOCK_SIZE or len(key) != self.KEY_SIZE:
            raise ValueError("Ciphertext must be 16 bytes and key 32 bytes")
        state = list(ciphertext)
        round_keys = self.key_expansion(key)
        state = self.add_round_key(state, round_keys[self.NUM_ROUNDS])
        for rnd in range(self.NUM_ROUNDS-1, 0, -1):
            state = self.inv_shift_rows(state)
            state = self.inv_sub_bytes(state)
            state = self.add_round_key(state, round_keys[rnd])
            state = self.inv_mix_columns(state)
        state = self.inv_shift_rows(state)
        state = self.inv_sub_bytes(state)
        state = self.add_round_key(state, round_keys[0])
        return bytes(state)

    def encrypt(self, plaintext: bytes):
        # split to blocks
        blocks = []
        for i in range(0, len(plaintext), self.BLOCK_SIZE):
            blocks.append(plaintext[i:i + self.BLOCK_SIZE])
        # pad
        if blocks:
            blocks[-1] = blocks[-1] + b'\0' * (self.BLOCK_SIZE - len(blocks[-1]))
        # encrypt with ECB
        output = b''
        for block in blocks:
            if block == b'\0' * 16:
                raise Exception('Wanna encrypt null? What a terrible waste of resources!')
            output += self.encrypt_block(block)
        return output


def main():
    """Serve one round of the challenge over stdin/stdout.

    Prints the ECB encryption of the local ``flag`` file under a fresh random
    key, then encrypts a single user-supplied hex string (at most 32 hex
    characters) with that same key before saying goodbye.
    """
    with open('flag', 'rb') as flag_file:
        secret = flag_file.read().strip()

    cipher = BetterAES(token_bytes(32))
    print(f"Flag ciphertext: {cipher.encrypt(secret).hex()}")

    try:
        print("Enter something you want to encrypt in hex form: ", end="")
        hex_in = input()
        if not hex_in:
            print('Enter valid hex string')
        elif len(hex_in) > 32:
            print('Input too long')
        else:
            chosen = bytes.fromhex(hex_in)
            print(f'Encrypted: {cipher.encrypt(chosen).hex()}')
    except ValueError:
        # bad hex from bytes.fromhex
        print('Invalid hex string')
    except Exception as err:
        # e.g. the "null block" refusal raised by BetterAES.encrypt
        print(err)
    print('Goodbye')


if __name__ == '__main__':
    main()

The first thing you notice is this:

1
2
# for easier testing
    sbox= list(range(0, 256))

What does this mean? Let’s look at what the S-box does and what it should look like.

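CryptoHack has a good explanation of how the S-box works and why it gives AES its high non-linearity. In this challenge, however, the S-box is the identity ($S[x] = x$ for every byte $x$), so SubBytes does nothing and AES loses its only non-linear step. AES is therefore transformed into an affine cipher whose encryption is $\mathrm{ct} = \mathrm{Transformed}(\mathrm{pt}) \oplus \mathrm{constant}$, where Transformed is the fixed linear map built from ShiftRows and MixColumns, and the constant is derived from the key: it is the XOR of all the round keys after each has passed through the linear layers of the remaining rounds.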

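So now we only need to recover the constant to get the flag. The obvious trick would be to encrypt an all-zero block, whose ciphertext would be the constant itself, but the code refuses to encrypt a plaintext block consisting only of zeros. How can we overcome this?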

Let’s revisit the encryption equation: $\mathrm{ct} = \mathrm{Transformed}(\mathrm{pt}) \oplus \mathrm{constant}$. What does Transformed() do? In AES you apply transformations such as ShiftRows and MixColumns to the plaintext, and the challenge code gives us these functions together with their inverses. So we can compute $\mathrm{Transformed}(\mathrm{pt})$ locally for any non-zero plaintext of our choice, then send the same plaintext to the server to get $\mathrm{ct} = \mathrm{Transformed}(\mathrm{pt}) \oplus \mathrm{constant}$. XORing that ciphertext with our locally transformed plaintext gives the constant: $\mathrm{constant} = \mathrm{ct} \oplus \mathrm{Transformed}(\mathrm{pt})$.

Now that we have the constant, we only need to XOR it with each block of the encrypted flag and then invert the transformation to get the flag: $\mathrm{pt} = \mathrm{Transformed}^{-1}(\mathrm{ct} \oplus \mathrm{constant})$.

Here is my solution:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from better_aes import BetterAES
from pwn import remote, xor


class Cracken(BetterAES):
    """Attack helper built on the challenge's BetterAES primitives.

    With the challenge's identity S-box, encryption is
    ``E(p) = linear_transform(p) XOR constant``. This class computes the
    key-independent part (``linear_transform``) and its inverse locally, and
    uses a leaked ``constant`` to decrypt ciphertext.
    """

    def __init__(self):
        # No real key is needed: an all-zero 32-byte key only satisfies the
        # base-class constructor; no method here uses round keys.
        super().__init__(bytes(32))

    def linear_transform(self, pt):
        """Apply every AES-256 round step except AddRoundKey to one block.

        Same round structure as the challenge's ``encrypt_block``: 13 rounds of
        SubBytes/ShiftRows/MixColumns, then a final round without MixColumns.
        """
        state = list(pt)
        for _ in range(1, self.NUM_ROUNDS):
            state = self.sub_bytes(state)
            state = self.shift_rows(state)
            state = self.mix_columns(state)
        state = self.sub_bytes(state)
        state = self.shift_rows(state)
        return bytes(state)

    def inv_linear_transform(self, ct):
        """Invert ``linear_transform`` by applying the inverse steps in reverse order."""
        state = list(ct)
        for _ in range(self.NUM_ROUNDS-1, 0, -1):
            state = self.inv_shift_rows(state)
            state = self.inv_sub_bytes(state)
            state = self.inv_mix_columns(state)
        state = self.inv_shift_rows(state)
        state = self.inv_sub_bytes(state)
        return bytes(state)

    def recover_flag(self, ct, constant):
        """Decrypt ``ct`` block by block given the leaked affine ``constant``.

        Each block ``c`` satisfies ``c = linear_transform(p) XOR constant``, so
        ``p = inv_linear_transform(c XOR constant)``. Zero padding is kept.
        """
        return b''.join(
            self.inv_linear_transform(xor(ct[i:i + self.BLOCK_SIZE], constant))
            for i in range(0, len(ct), self.BLOCK_SIZE)
        )


def main(host="localhost", port=13377):
    """Recover the flag from a running challenge service at ``host:port``.

    Sends one chosen plaintext block, derives the affine constant from its
    ciphertext, then uses that constant to decrypt the leaked flag ciphertext.
    """
    crack = Cracken()
    # Any single 16-byte block works as long as it is not all zeros (the server
    # refuses null blocks) and fits the 32-hex-character input limit.
    payload = b't4qi' * 4
    transformed_payload = crack.linear_transform(payload)

    conn = remote(host, port)
    try:
        banner = conn.recvuntil(b'Enter something you want to encrypt in hex form: ').decode()
        # banner contains "Flag ciphertext: <hex>"
        flag_ciphertext = bytes.fromhex(banner.split('Flag ciphertext:')[1].split()[0])

        conn.sendline(payload.hex().encode())
        response = conn.recvuntil(b'Goodbye').decode()
    finally:
        conn.close()

    if 'Encrypted:' not in response:
        raise RuntimeError(f'unexpected server response: {response!r}')
    encrypted_payload = bytes.fromhex(response.split('Encrypted:')[1].split()[0])

    constant = xor(transformed_payload, encrypted_payload)
    flag = crack.recover_flag(flag_ciphertext, constant)
    # the server zero-pads only the final block, so strip trailing zeros only
    print(flag.rstrip(b'\0').decode())


if __name__ == '__main__':
    main()

Hope you enjoyed the write-up.

Categories: