Twist and Shout
Recovering the internal state of Python's Mersenne Twister PRNG.

Problem

Wise men once said, "Well, shake it up, baby, now Twist and shout come on and work it on out". I obliged, now the flag is as twisted as my sense of humour.

nc crypto.zh3r0.cf 5555

Solution

We are given the following source code:
```python
from secret import flag
from Crypto.Util.number import *
import os
import random

state_len = 624*4
right_pad = random.randint(0,state_len-len(flag))
left_pad = state_len-len(flag)-right_pad
state_bytes = os.urandom(left_pad)+flag+os.urandom(right_pad)

state = tuple( int.from_bytes(state_bytes[i:i+4],'big') for i in range(0,state_len,4) )

random.setstate((3,state+(624,),None))
random.randint(0,0)
outputs = [random.getrandbits(32) for i in range(624)]
print(*outputs,sep='\n')
```
A few things here:
1. The state is built from exactly 624 * 4 bytes, i.e. 624 32-bit words, and the flag is hidden somewhere inside, surrounded by random padding.
2. Python's random pseudo-random number generator (PRNG) state is set to the state tuple, with an additional number 624 at the back.
3. Then, after a call to random.randint(0, 0), 624 32-bit integers are generated using the PRNG and printed.

Pseudo-RNGs

Note that the left and right padding use os.urandom().
```python
state_bytes = os.urandom(left_pad)+flag+os.urandom(right_pad)
```
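This is the cryptographically secure way of generating random bytes in Python. os.urandom() takes its bytes from the operating system's random number generator, which is seeded from unpredictable real-world entropy sources, so its output can be neither predicted nor reproduced.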
The random module, on the other hand, implements a deterministic PRNG. Deterministic PRNGs are predictable. For instance, when using the same seed, the "random" numbers will be the same each time.
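Here is a minimal sketch of that determinism, using only the standard library. Two random.Random instances created with the same seed return identical sequences. The seed 1337 is arbitrary. os.urandom(), by contrast, takes no seed at all.

```python
import os
import random

# Two independent Mersenne Twister instances seeded identically
a = random.Random(1337)
b = random.Random(1337)

seq_a = [a.getrandbits(32) for _ in range(5)]
seq_b = [b.getrandbits(32) for _ in range(5)]

# Same seed -> the same "random" sequence every time
print(seq_a == seq_b)  # True

# os.urandom() has no seed: it asks the OS for fresh random bytes
print(os.urandom(8).hex())
```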

Mersenne Twister

In Python, the random module is implemented using the Mersenne Twister (MT19937). The generator works on an internal state of 624 32-bit values and keeps track of the current position i in this state array. Each 32-bit "random number" it returns is state[i] after some mangling, a scrambling step known as tempering.
If we look at the CPython source code, we can see exactly how this is implemented:
```c
static unsigned long
genrand_int32(RandomObject *self)
{
    unsigned long y;
    static unsigned long mag01[2]={0x0UL, MATRIX_A};
    /* mag01[x] = x * MATRIX_A  for x=0,1 */
    unsigned long *mt;

    mt = self->state;
    if (self->index >= N) { /* generate N words at one time */
        int kk;

        for (kk=0;kk<N-M;kk++) {
            y = (mt[kk]&UPPER_MASK)|(mt[kk+1]&LOWER_MASK);
            mt[kk] = mt[kk+M] ^ (y >> 1) ^ mag01[y & 0x1UL];
        }
        for (;kk<N-1;kk++) {
            y = (mt[kk]&UPPER_MASK)|(mt[kk+1]&LOWER_MASK);
            mt[kk] = mt[kk+(M-N)] ^ (y >> 1) ^ mag01[y & 0x1UL];
        }
        y = (mt[N-1]&UPPER_MASK)|(mt[0]&LOWER_MASK);
        mt[N-1] = mt[M-1] ^ (y >> 1) ^ mag01[y & 0x1UL];

        self->index = 0;
    }

    y = mt[self->index++];
    y ^= (y >> 11);
    y ^= (y << 7) & 0x9d2c5680UL;
    y ^= (y << 15) & 0xefc60000UL;
    y ^= (y >> 18);
    return y;
}
```
The if statement checks whether the index has reached the end of the array (index >= N, where N is 624). If so, the whole state array is regenerated into the "next state" and the index is reset to 0.
After that, whether or not a regeneration happened, it simply does the following to the number at the current index:
```c
y = mt[self->index++];
y ^= (y >> 11);
y ^= (y << 7) & 0x9d2c5680UL;
y ^= (y << 15) & 0xefc60000UL;
y ^= (y >> 18);
return y;
```
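For experimenting, the tempering steps can be ported directly to Python. The function name temper is just a label. The sketch checks the port against Python's own generator. It loads a state with index 0, so the very next output is the tempered form of state word 0, with no regeneration in between.

```python
import random

def temper(y):
    """Python port of the four tempering steps from genrand_int32()."""
    y ^= y >> 11
    y ^= (y << 7) & 0x9D2C5680
    y ^= (y << 15) & 0xEFC60000
    y ^= y >> 18
    return y & 0xFFFFFFFF  # a no-op for 32-bit inputs, kept for clarity

# Borrow a real 624-word state from Python's generator
random.seed(2021)
words = list(random.getstate()[1][:624])

# Index 0: the next output is read from words[0] without regenerating first
random.setstate((3, tuple(words) + (0,), None))
first = random.getrandbits(32)

print(first == temper(words[0]))  # True
```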

Internal State

Let's take a look at these two lines of the source code:
```python
state = tuple( int.from_bytes(state_bytes[i:i+4],'big') for i in range(0,state_len,4) )
random.setstate((3,state+(624,),None))
```
random.setstate() allows us to set the internal state of the PRNG directly. We know that the tuple we pass in contains the state array, but what exactly is the 624 at the back?
The Python documentation doesn't say much. For setstate(), it just says that:
"state should have been obtained from a previous call to getstate(), and setstate() restores the internal state of the generator to what it was at the time getstate() was called."
and for getstate(), that it will:
"Return an object capturing the current internal state of the generator. This object can be passed to setstate() to restore the state."
Well, that doesn't really help, but again, the CPython source code gives us some answers.
```c
static PyObject *
random_getstate(RandomObject *self)
{
    PyObject *state;
    PyObject *element;
    int i;

    state = PyTuple_New(N+1);
    if (state == NULL)
        return NULL;
    for (i=0; i<N ; i++) {
        element = PyLong_FromUnsignedLong(self->state[i]);
        if (element == NULL)
            goto Fail;
        PyTuple_SET_ITEM(state, i, element);
    }
    element = PyLong_FromLong((long)(self->index));
    if (element == NULL)
        goto Fail;
    PyTuple_SET_ITEM(state, i, element);
    return state;

Fail:
    Py_DECREF(state);
    return NULL;
}
```
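Notice how the last element of the state tuple is set? It is the value of self->index, which, as we saw above, is the current position in the state array. setstate() does the reverse and reads this last element back into self->index. The trailing 624 in the challenge therefore means the generator starts at position 624.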
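We can confirm this from Python itself with a small sketch. The seed 0 is arbitrary. Right after seeding, the trailing index is 624, so the next call regenerates the array. After one 32-bit output, the index is 1.

```python
import random

random.seed(0)
version, internal, gauss_next = random.getstate()

# Version 3 format: 624 state words plus the trailing index
print(version, len(internal))   # 3 625
print(internal[-1])             # 624: seeding leaves the index at N

random.getrandbits(32)          # regenerates the array, then reads word 0
print(random.getstate()[1][-1]) # 1
```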

Recovering the Internal State

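The key idea is that the tempering step is invertible: each 32-bit output can be "untempered" back into the state word it came from. Since the state array consists of 624 32-bit integers, 624 consecutive 32-bit outputs are enough to undo the above mangling and recover the whole state array.
Credits to More Smoked Leet Chicken for this untempering script, saved here as mt.py so the later scripts can import it! It is taken from http://mslc.ctf.su/wp/confidence-ctf-2015-rsa2-crypto-500/.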
```python
#-*- coding:utf-8 -*-
# mt.py: undo MT19937 tempering

TemperingMaskB = 0x9d2c5680
TemperingMaskC = 0xefc60000

def untemper(y):
    # Undo the four tempering steps in reverse order
    y = undoTemperShiftL(y)
    y = undoTemperShiftT(y)
    y = undoTemperShiftS(y)
    y = undoTemperShiftU(y)
    return y

def undoTemperShiftL(y):
    # Inverts y ^= y >> 18 (the top 18 bits are unchanged, so one step suffices)
    last14 = y >> 18
    final = y ^ last14
    return final

def undoTemperShiftT(y):
    # Inverts y ^= (y << 15) & 0xefc60000 (one step suffices)
    first17 = y << 15
    final = y ^ (first17 & TemperingMaskC)
    return final

def undoTemperShiftS(y):
    # Inverts y ^= (y << 7) & 0x9d2c5680, recovering 7 more bits per step
    a = y << 7
    b = y ^ (a & TemperingMaskB)
    c = b << 7
    d = y ^ (c & TemperingMaskB)
    e = d << 7
    f = y ^ (e & TemperingMaskB)
    g = f << 7
    h = y ^ (g & TemperingMaskB)
    i = h << 7
    final = y ^ (i & TemperingMaskB)
    return final

def undoTemperShiftU(y):
    # Inverts y ^= y >> 11, recovering 11 more bits per step
    a = y >> 11
    b = y ^ a
    c = b >> 11
    final = y ^ c
    return final
```
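Here is a self-contained round-trip test showing that tempering really is invertible. Instead of the unrolled functions above, it uses two small helpers, undo_right and undo_left. These are hypothetical names, not part of the MSLC script. They undo y ^= y >> s and y ^= (y << s) & mask by re-applying the step until every bit is fixed, which is the same idea the script above writes out by hand.

```python
import random

def temper(y):
    y ^= y >> 11
    y ^= (y << 7) & 0x9D2C5680
    y ^= (y << 15) & 0xEFC60000
    y ^= y >> 18
    return y

def undo_right(y, shift):
    # Invert y ^= y >> shift: each pass fixes `shift` more bits from the top
    x = y
    for _ in range(32 // shift):
        x = y ^ (x >> shift)
    return x

def undo_left(y, shift, mask):
    # Invert y ^= (y << shift) & mask: each pass fixes `shift` more bits from the bottom
    x = y
    for _ in range(32 // shift):
        x = y ^ ((x << shift) & mask)
    return x

def untemper(y):
    # Undo the tempering steps in reverse order
    y = undo_right(y, 18)
    y = undo_left(y, 15, 0xEFC60000)
    y = undo_left(y, 7, 0x9D2C5680)
    return undo_right(y, 11)

rng = random.Random(7)
samples = [0, 1, 0xFFFFFFFF] + [rng.getrandbits(32) for _ in range(1000)]
print(all(untemper(temper(x)) == x for x in samples))  # True
```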
After receiving the 624 outputs from the server, we can store them in an outputs array and recover the original state:
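```python
import random
from mt import untemper

# outputs holds the 624 integers received from the server.
# Untemper each one back into its state word; the trailing 0 is the index,
# so the next output is read from mt_state[0] without regenerating first.
mt_state = tuple(list(map(untemper, outputs)) + [0])
random.setstate((3, mt_state, None))
outputs2 = [random.getrandbits(32) for i in range(624)]

# Sanity check: the cloned generator reproduces the same outputs
assert outputs2 == outputs
```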
If the sanity check passes, our untempering is correct and we have recovered the state array that produced these outputs. However, our work is not done! Remember how the number 624 was added to the back of the state tuple?
```python
random.setstate((3,state+(624,),None))
```
Looking back at the CPython source above, this means the index starts at 624. Before the first random output is even generated, the state array is therefore regenerated:
```c
if (self->index >= N) { /* generate N words at one time */
    int kk;

    for (kk=0;kk<N-M;kk++) {
        y = (mt[kk]&UPPER_MASK)|(mt[kk+1]&LOWER_MASK);
        mt[kk] = mt[kk+M] ^ (y >> 1) ^ mag01[y & 0x1UL];
    }
    for (;kk<N-1;kk++) {
        y = (mt[kk]&UPPER_MASK)|(mt[kk+1]&LOWER_MASK);
        mt[kk] = mt[kk+(M-N)] ^ (y >> 1) ^ mag01[y & 0x1UL];
    }
    y = (mt[N-1]&UPPER_MASK)|(mt[0]&LOWER_MASK);
    mt[N-1] = mt[M-1] ^ (y >> 1) ^ mag01[y & 0x1UL];

    self->index = 0;
}
```
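The state we obtained from our script above comes from untempering the 624 outputs, so it is the state array the generator was reading from, starting at index 0. That is the array produced after the generator notices that the current position is 624 and regenerates it. It is not the array that was passed to setstate(), which is the one containing the flag.

Strictly speaking, the challenge's call to random.randint(0, 0) also consumes output. In CPython it goes through getrandbits(1), which draws at least one 32-bit word, so the 624 printed outputs start one or more positions into the regenerated array. The arrays we recover are therefore slightly rotated, and a few words may come out wrong. Since we only need the flag bytes to appear somewhere in the result, this does not get in the way here.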
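To see the regeneration step in action, the sketch below ports the C loop to Python as twist(), folding the three loops into one with modular indexing. The constants N = 624, M = 397 and MATRIX_A = 0x9908b0df are the standard MT19937 parameters; the C excerpt only references them by name. The check mirrors the challenge: load a state with index 624, read 624 outputs, and compare them with the tempered words of twist(state).

```python
import random

N, M = 624, 397
MATRIX_A = 0x9908B0DF
UPPER_MASK = 0x80000000
LOWER_MASK = 0x7FFFFFFF

def twist(state):
    """Python port of the 'generate N words at one time' block (in-place semantics)."""
    mt = list(state)
    for kk in range(N):
        y = (mt[kk] & UPPER_MASK) | (mt[(kk + 1) % N] & LOWER_MASK)
        mt[kk] = mt[(kk + M) % N] ^ (y >> 1) ^ (MATRIX_A if y & 1 else 0)
    return mt

def temper(y):
    y ^= y >> 11
    y ^= (y << 7) & 0x9D2C5680
    y ^= (y << 15) & 0xEFC60000
    y ^= y >> 18
    return y

random.seed(5555)
old = list(random.getstate()[1][:N])

# Index 624 forces a regeneration before the first output, as in the challenge
random.setstate((3, tuple(old) + (N,), None))
outputs = [random.getrandbits(32) for _ in range(N)]

print(outputs == [temper(w) for w in twist(old)])  # True
```

Because the loop updates mt in place, position 623 reads the already-updated mt[0], exactly as mt[N-1] does in the C code.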

Recovering the Previous State

What we need to do, then, is to recover the previous state of the generator. I found this great post containing an algorithm to recover the previous state array.
The algorithm looks like this:
```
// 32-bit int arithmetic (Java-style); in C, use uint32_t instead,
// since left-shifting a negative signed int is undefined behaviour
for (int i = 623; i >= 0; i--) {
    int result = 0;
    // first we calculate the first bit
    int tmp = state[i];
    tmp ^= state[(i + 397) % 624];
    // if the top bit of tmp is set, MATRIX_A (0x9908b0df) was applied: unapply it
    if ((tmp & 0x80000000) == 0x80000000) {
        tmp ^= 0x9908b0df;
    }
    // the second bit of tmp is the first bit of the result
    result = (tmp << 1) & 0x80000000;

    // work out the remaining 31 bits
    tmp = state[(i - 1 + 624) % 624];
    tmp ^= state[(i + 396) % 624];
    if ((tmp & 0x80000000) == 0x80000000) {
        tmp ^= 0x9908b0df;
        // since it was odd, the last bit must have been 1
        result |= 1;
    }
    // extract the final 30 bits
    result |= (tmp << 1) & 0x7fffffff;
    state[i] = result;
}
```
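The reason the top bit of tmp reveals whether MATRIX_A was applied follows from the regeneration step. In the equations below, $A$ is MATRIX_A (0x9908b0df), $M = 397$, indices are taken modulo 624, and $\mathrm{mt}'$ is the regenerated array:

$$
y_i = \bigl(\mathrm{mt}[i] \mathbin{\&} \mathtt{0x80000000}\bigr) \,\big|\, \bigl(\mathrm{mt}[i+1] \mathbin{\&} \mathtt{0x7fffffff}\bigr)
$$

$$
\mathrm{mt}'[i] = \mathrm{mt}[i+M] \oplus (y_i \gg 1) \oplus (y_i \bmod 2)\,A
\quad\Longrightarrow\quad
\mathrm{mt}'[i] \oplus \mathrm{mt}[i+M] = (y_i \gg 1) \oplus (y_i \bmod 2)\,A
$$

Since $y_i \gg 1 < 2^{31}$ while the top bit of $A$ is set, the top bit of the left-hand side equals $y_i \bmod 2$. After XORing $A$ back out when that bit is set, shifting left by one returns bits 31 to 1 of $y_i$. Bit 31 is the top bit of the old mt[i]. Bits 30 to 1 belong to the old mt[i+1], whose lowest bit is the oddness just detected.

This is why the algorithm takes the top bit of the result from position i and the remaining 31 bits from position i - 1. Here mt[i+M] means whatever value sat at that index when position i was updated (already regenerated when $i \ge 227$). That is exactly what the backwards, in-place loop finds at state[(i + 397) % 624] when it reaches i. For $i = 623$, mt[i+1] is the already-regenerated mt[0], which is also why the low 31 bits of the old mt[0] cannot be recovered.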
We can then recover the previous state:
```python
state = tuple( int.from_bytes(state_bytes[i:i+4],'big') for i in range(0,state_len,4) )
random.setstate((3,state+(624,),None)) # This state has index 624

...

# From the state with index 0, recover previous state with index 624.
def get_prev_state(state):
    for i in range(623, -1, -1):
        result = 0
        # Top bit: undo the XOR with state[i + 397] and, if applied, MATRIX_A
        tmp = state[i]
        tmp ^= state[(i + 397) % 624]
        if tmp & 0x80000000:
            tmp ^= 0x9908b0df
        result = (tmp << 1) & 0x80000000

        # Remaining 31 bits come from the update of the previous position
        tmp = state[(i - 1 + 624) % 624]
        tmp ^= state[(i + 396) % 624]
        if tmp & 0x80000000:
            tmp ^= 0x9908b0df
            result |= 1

        result |= (tmp << 1) & 0x7fffffff
        state[i] = result

    return state

prev_state = get_prev_state(list(mt_state[:624]))

# Sanity check (index 0 is skipped: only its most significant bit
# influences the next state, so its low 31 bits cannot be recovered)
for i in range(1, len(state)):
    assert state[i] == prev_state[i]
```
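The backwards algorithm can be checked without the challenge server using a self-contained round trip. The sketch runs a forward twist() (a Python port of the C loop) on an arbitrary random state, then undoes it with untwist(). untwist() follows the same steps as get_prev_state above but works on a copy. Both function names are illustrative.

```python
import random

N, M = 624, 397
MATRIX_A = 0x9908B0DF

def twist(state):
    # Forward regeneration step, in-place semantics as in the C code
    mt = list(state)
    for kk in range(N):
        y = (mt[kk] & 0x80000000) | (mt[(kk + 1) % N] & 0x7FFFFFFF)
        mt[kk] = mt[(kk + M) % N] ^ (y >> 1) ^ (MATRIX_A if y & 1 else 0)
    return mt

def untwist(state):
    # Same steps as get_prev_state, walking backwards over a copy
    mt = list(state)
    for i in range(N - 1, -1, -1):
        tmp = mt[i] ^ mt[(i + M) % N]
        if tmp & 0x80000000:
            tmp ^= MATRIX_A
        result = (tmp << 1) & 0x80000000

        tmp = mt[(i - 1) % N] ^ mt[(i + M - 1) % N]
        if tmp & 0x80000000:
            tmp ^= MATRIX_A
            result |= 1
        result |= (tmp << 1) & 0x7FFFFFFF
        mt[i] = result
    return mt

rng = random.Random(31337)
old = [rng.getrandbits(32) for _ in range(N)]
recovered = untwist(twist(old))

print(recovered[1:] == old[1:])            # True
print(recovered[0] >> 31 == old[0] >> 31)  # True: only the top bit of word 0 survives
```

Word 0 is the one exception, which matches the sanity check above starting from index 1.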
Sidenote: recovering the previous state essentially allows us to obtain "past" outputs. Being able to know both past and future outputs can be a serious security flaw in real-world applications. In a real application, we might obtain the required 624 outputs to recover the internal state of the PRNG via consecutive web requests, etc.
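The sketch below makes the "future outputs" half of that sidenote concrete by cloning a generator from 624 consecutive outputs. The victim generator is seeded from OS entropy, so its state is never inspected directly. The helpers undo_right, undo_left and untemper are a compact re-implementation for self-containment, not the MSLC script. The untempered words are loaded with index 624, rather than the index 0 used in the earlier sanity check. This makes the clone regenerate its array first, so its next outputs line up with the victim's next outputs.

```python
import random

def undo_right(y, shift):
    # Invert y ^= y >> shift
    x = y
    for _ in range(32 // shift):
        x = y ^ (x >> shift)
    return x

def undo_left(y, shift, mask):
    # Invert y ^= (y << shift) & mask
    x = y
    for _ in range(32 // shift):
        x = y ^ ((x << shift) & mask)
    return x

def untemper(y):
    y = undo_right(y, 18)
    y = undo_left(y, 15, 0xEFC60000)
    y = undo_left(y, 7, 0x9D2C5680)
    return undo_right(y, 11)

victim = random.Random()  # seeded from OS entropy; its state is never read
observed = [victim.getrandbits(32) for _ in range(624)]

# Rebuild the state array; index 624 makes the clone regenerate before its next output
clone = random.Random()
clone.setstate((3, tuple(untemper(o) for o in observed) + (624,), None))

predicted = [clone.getrandbits(32) for _ in range(10)]
actual = [victim.getrandbits(32) for _ in range(10)]
print(predicted == actual)  # True
```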

Solving the Challenge

Once we obtain the original state, we simply have to convert the numbers in the tuple back to bytes (4 big-endian bytes each, matching how the challenge built the state) and look for the flag in the output.
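```python
result = b""
for num in prev_state:
    # Pack each word as exactly 4 big-endian bytes, mirroring how the challenge
    # built the state (long_to_bytes would drop leading zero bytes)
    result += num.to_bytes(4, 'big')

print(result)
```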
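A slightly more robust variant of this step is sketched below. It packs each word with int.to_bytes(4, 'big') and then searches for the flag with a regular expression instead of reading it by eye. The zh3r0{...} flag pattern and the helper name extract_flag are assumptions for illustration. The demo builds its own fake state the same way the challenge does, rather than using real challenge data.

```python
import os
import re

def extract_flag(words, pattern=rb"zh3r0\{[^}]*\}"):
    """Hypothetical helper: pack 32-bit words big-endian and regex-search for a flag."""
    blob = b"".join(w.to_bytes(4, "big") for w in words)
    match = re.search(pattern, blob)
    return match.group(0) if match else None

# Demo: build a state like the challenge does, with a dummy flag
fake_flag = b"zh3r0{not_the_real_flag}"
state_len = 624 * 4
left = 100
state_bytes = os.urandom(left) + fake_flag + os.urandom(state_len - left - len(fake_flag))
words = [int.from_bytes(state_bytes[i:i + 4], "big") for i in range(0, state_len, 4)]

print(extract_flag(words))
```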
Here's the complete solver script:
```python
from pwn import *
from mt import untemper

conn = remote('crypto.zh3r0.cf', 5555)

# Read the 624 outputs printed by the server
outputs = []
for i in range(624):
    num = int(conn.recvline().decode().strip())
    outputs.append(num)

# Untemper each output to get the regenerated state array
mt_state = tuple(list(map(untemper, outputs)) + [0])

# Undo the regeneration to get the state that was passed to setstate()
def get_prev_state(state):
    for i in range(623, -1, -1):
        result = 0
        tmp = state[i]
        tmp ^= state[(i + 397) % 624]
        if tmp & 0x80000000:
            tmp ^= 0x9908b0df
        result = (tmp << 1) & 0x80000000

        tmp = state[(i - 1 + 624) % 624]
        tmp ^= state[(i + 396) % 624]
        if tmp & 0x80000000:
            tmp ^= 0x9908b0df
            result |= 1

        result |= (tmp << 1) & 0x7fffffff
        state[i] = result

    return state

prev_state = get_prev_state(list(mt_state[:624]))

result = b""
for num in prev_state:
    result += num.to_bytes(4, 'big')  # 4 big-endian bytes per word

print(result)
```
And the output contains the flag: