Notice how the moduli are chosen. Three primes, $p$, $q$ and $r$, are generated; the first modulus is $n_a = p \times q$ and the second is $n_b = p \times r$.
Our first observation should be that both moduli share $p$ as a common factor.
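Recall that in RSA, for $n = p \times q$, the private exponent $d$ satisfies

$$ed \equiv 1 \pmod{\operatorname{lcm}(p-1,\, q-1)}$$

so $ed - 1$ is always a multiple of both $p - 1$ and $q - 1$. When $d$ is computed modulo $\varphi(n) = (p-1)(q-1)$ instead, which the rest of this approach assumes, this becomes, for some $k \in \mathbb{Z}$,

$$ed = 1 + k(p-1)(q-1)$$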
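A quick numeric check of that distinction, using toy primes 1009 and 1013 (chosen only for illustration) and $e = 65537$. It relies on Python 3.8+'s `pow(e, -1, m)` for the modular inverse.

```python
from math import gcd

# Toy primes for illustration only; real RSA primes are hundreds of digits long.
p, q = 1009, 1013
e = 65537

phi = (p - 1) * (q - 1)           # Euler's totient of n = p*q
lam = phi // gcd(p - 1, q - 1)    # lcm(p - 1, q - 1), Carmichael's lambda

d_phi = pow(e, -1, phi)   # private exponent reduced modulo phi(n)
d_lam = pow(e, -1, lam)   # private exponent reduced modulo lcm(p-1, q-1)

# Reduced modulo phi, e*d - 1 is an exact multiple of (p-1)(q-1).
k = (e * d_phi - 1) // phi
print("k =", k, "exact:", (e * d_phi - 1) % phi == 0)

# Reduced modulo lcm, e*d - 1 is still a multiple of p - 1 and of q - 1 ...
print((e * d_lam - 1) % (p - 1) == 0, (e * d_lam - 1) % (q - 1) == 0)
# ... but need not be a multiple of (p-1)(q-1).
print((e * d_lam - 1) % phi == 0)
```

With these values `k` is 53465 and the last line prints `False`, which is why the derivation assumes $d$ was reduced modulo $\varphi(n)$.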
For $n_a = p \times q$ and $n_b = p \times r$:

$$\begin{aligned} e d_a - 1 &= k_a (p-1)(q-1) \\ e d_b - 1 &= k_b (p-1)(r-1) \end{aligned}$$
Since we are given $e$, $d_a$ and $d_b$, we know the values of $k_a(p-1)(q-1)$ and $k_b(p-1)(r-1)$. Both are multiples of $p - 1$, so their greatest common divisor (GCD) is also a multiple of $p - 1$; let's call it $\alpha(p-1)$. Because $q - 1$ and $r - 1$ are both even, $\alpha \ge 2$, and the approach relies on $\alpha$ being small.
We can write the following code to obtain $\alpha(p-1)$:
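```python
from decimal import Decimal, getcontext
from math import gcd

# Enough precision for Decimal to hold these large integers exactly.
getcontext().prec = 1000

# e, d_a and d_b are the values given by the challenge.
de_a = d_a * e
de_b = d_b * e

# e*d_a - 1 and e*d_b - 1 are both multiples of p - 1,
# so their GCD is alpha * (p - 1).
p_multiple = Decimal(gcd(de_a - 1, de_b - 1))
```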
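To see the GCD step on concrete numbers, here is a toy version with three small primes standing in for the challenge's $p$, $q$ and $r$ (values chosen for illustration, not taken from the challenge), again assuming $e = 65537$ and $d$ reduced modulo $\varphi$.

```python
from math import gcd

# Toy stand-ins for the challenge's primes; p is the shared factor.
p, q, r = 1009, 1013, 1019
e = 65537

# Private exponents reduced modulo phi, as the derivation assumes.
d_a = pow(e, -1, (p - 1) * (q - 1))
d_b = pow(e, -1, (p - 1) * (r - 1))

# The GCD of e*d_a - 1 and e*d_b - 1 is alpha * (p - 1).
g = gcd(e * d_a - 1, e * d_b - 1)
alpha, remainder = divmod(g, p - 1)
print("g =", g, "alpha =", alpha, "remainder =", remainder)
```

Here the remainder is 0 and `alpha` comes out as 22, comfortably within a small search range.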
We can then search for $\alpha$ by dividing $\alpha(p-1)$ by small integers $i$ and checking whether the quotient plus one is prime, since $p$ must be prime.
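```python
from random import randint

def isPrime(n, k=5):
    # Miller-Rabin probabilistic primality test using k random bases.
    if n < 2:
        return False
    # Quick trial division by small primes.
    for small_prime in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]:
        if n % small_prime == 0:
            return n == small_prime
    # Write n - 1 as 2^s * d with d odd.
    s, d = 0, n - 1
    while d % 2 == 0:
        s, d = s + 1, d // 2
    for _ in range(k):
        x = pow(randint(2, n - 1), d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(1, s):
            x = (x * x) % n
            if x == 1:
                return False
            if x == n - 1:
                break
        else:
            # Never reached n - 1: this base proves n composite.
            return False
    return True
```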
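```python
# Assumes alpha < 100. Only exact divisors of alpha*(p-1) are tried,
# because truncating a non-integer quotient could give an unrelated prime.
for i in range(2, 100):
    if p_multiple % i != 0:
        continue
    p = p_multiple // i + 1
    if isPrime(int(p)):
        print('p =', p)
        break
```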
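On toy-sized numbers a primality filter is much weaker than on real 1024-bit primes, which makes one caveat easy to see: more than one divisor can give a prime. The sketch below uses the illustrative primes 1009, 1013 and 1019, and a hypothetical trial-division `is_prime` helper in place of Miller-Rabin; it collects every candidate instead of stopping at the first.

```python
from math import gcd, isqrt

def is_prime(n):
    # Hypothetical helper: trial division, adequate only for toy-sized n.
    if n < 2:
        return False
    return all(n % f for f in range(2, isqrt(n) + 1))

# Toy stand-ins for the challenge's primes; p is the shared factor.
p, q, r = 1009, 1013, 1019
e = 65537
d_a = pow(e, -1, (p - 1) * (q - 1))
d_b = pow(e, -1, (p - 1) * (r - 1))
p_multiple = gcd(e * d_a - 1, e * d_b - 1)   # alpha * (p - 1)

# Every small i that divides exactly and gives a prime is a candidate for p.
candidates = [p_multiple // i + 1
              for i in range(2, 100)
              if p_multiple % i == 0 and is_prime(p_multiple // i + 1)]
print(candidates)
```

Here the first hit is 7393 (from $i = 3$), not $p = 1009$ (from $i = \alpha = 22$), so a loop that stops at the first prime would pick the wrong factor. With full-size primes a spurious prime is far less likely, and a wrong $p$ would surface later when no candidate decrypts to readable text.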
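After finding $p$, we simply note that

$$\frac{e d_a - 1}{p - 1} = k_a (q - 1)$$

and use a similar method to find $q$. Since $d_a < (p-1)(q-1)$, we also have $k_a < e$, so if $e$ is the common 65537, a search over $k_a$ below 100000 covers every possible value. Then, we attempt to decrypt the ciphertext and keep the candidate whose plaintext decodes as valid text:

$$m_a = c_a^{d_a} \bmod n_a$$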
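```python
from Crypto.Util.number import long_to_bytes  # pycryptodome

# (e*d_a - 1) / (p - 1) = k_a * (q - 1); this division is exact.
q_multiple = (de_a - 1) // (int(p) - 1)
for i in range(1, 100000):  # k_a can be as small as 1
    if q_multiple % i != 0:
        continue
    q = q_multiple // i + 1
    if isPrime(q):
        m = pow(int(ct_a), int(d_a), int(p) * q)
        msg = long_to_bytes(m)
        try:
            # A wrong q yields random bytes that almost never decode as UTF-8.
            print(msg.decode())
            break
        except UnicodeDecodeError:
            continue
```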
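The bound $k_a < e$ and the need for a plaintext check can both be seen on toy numbers. In this sketch (toy primes and a sample plaintext chosen for illustration; `is_prime` is a hypothetical trial-division helper), every divisor $k$ of $k_a(q-1)$ that yields a prime gives a plausible-looking modulus, because $(p-1)(\text{cand}-1)$ divides $e d_a - 1$ for all of them. That is why the search relies on the plaintext decoding cleanly rather than on an encrypt/decrypt round trip.

```python
from math import isqrt

def is_prime(n):
    # Hypothetical helper: trial division, adequate only for toy-sized n.
    if n < 2:
        return False
    return all(n % f for f in range(2, isqrt(n) + 1))

# Toy stand-ins; real challenge values are far larger.
p, q = 1009, 1013
e = 65537
d_a = pow(e, -1, (p - 1) * (q - 1))

q_multiple = (e * d_a - 1) // (p - 1)    # = k_a * (q - 1), exact division
k_a = q_multiple // (q - 1)
print("k_a =", k_a, "below e:", k_a < e)

# (p-1)*(cand-1) divides e*d_a - 1 for every candidate, so an RSA
# round trip modulo p*cand succeeds for any prime cand != p.
candidates = [q_multiple // k + 1
              for k in range(1, e)
              if q_multiple % k == 0 and is_prime(q_multiple // k + 1)]
print(len(candidates), "prime candidates, q among them:", q in candidates)

# Decrypting with the true q recovers a sample plaintext.
m = 424242
ct_a = pow(m, e, p * q)
print(pow(ct_a, d_a, p * q) == m)
```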
The method to find r and decode the second part of the flag is exactly the same.
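```python
# (e*d_b - 1) / (p - 1) = k_b * (r - 1); this division is exact.
r_multiple = (de_b - 1) // (int(p) - 1)
for i in range(1, 100000):  # k_b can be as small as 1
    if r_multiple % i != 0:
        continue
    r = r_multiple // i + 1
    if isPrime(r):
        m = pow(int(ct_b), int(d_b), int(p) * r)
        msg = long_to_bytes(m)
        try:
            print(msg.decode())
            break
        except UnicodeDecodeError:
            continue
```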
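Since the two halves differ only in which exponent and ciphertext are used, the search can be factored into one function. The sketch below is a hypothetical `recover_and_decrypt` helper, not the author's code: it tries every $k < e$ (assuming $d$ is reduced modulo $\varphi$), uses Miller-Rabin with fixed small-prime bases, converts with the built-in `int.to_bytes` instead of pycryptodome's `long_to_bytes`, and keeps the first candidate whose plaintext decodes as UTF-8. The demo builds key material from the Mersenne primes $2^{89}-1$, $2^{107}-1$ and $2^{127}-1$ purely for illustration.

```python
def is_probable_prime(n):
    """Miller-Rabin with the first twelve primes as fixed bases."""
    bases = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)
    if n < 2:
        return False
    for b in bases:
        if n % b == 0:
            return n == b
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for a in bases:
        x = pow(a, d, n)
        if x in (1, n - 1):
            continue
        for _ in range(s - 1):
            x = x * x % n
            if x == n - 1:
                break
        else:
            return False  # a witnesses that n is composite
    return True


def recover_and_decrypt(e, d, p, ct):
    """Hypothetical helper: find the prime paired with p and decrypt ct.

    Works for either half: (d_a, ct_a) yields q, (d_b, ct_b) yields r.
    Assumes d < phi(n), so the multiplier k satisfies 1 <= k < e.
    """
    multiple = (e * d - 1) // (p - 1)  # = k * (partner - 1)
    for k in range(1, e):
        if multiple % k:
            continue
        partner = multiple // k + 1
        if not is_probable_prime(partner):
            continue
        m = pow(ct, d, p * partner)
        msg = m.to_bytes((m.bit_length() + 7) // 8, "big")
        try:
            return partner, msg.decode()
        except UnicodeDecodeError:
            continue  # wrong partner: the "plaintext" is random bytes
    return None


# Toy key material from known Mersenne primes (illustration only).
p, q, r = 2**89 - 1, 2**107 - 1, 2**127 - 1
e = 65537
d_a = pow(e, -1, (p - 1) * (q - 1))
d_b = pow(e, -1, (p - 1) * (r - 1))
ct_a = pow(int.from_bytes(b"flag{shared_", "big"), e, p * q)
ct_b = pow(int.from_bytes(b"prime}", "big"), e, p * r)

print(recover_and_decrypt(e, d_a, p, ct_a))
print(recover_and_decrypt(e, d_b, p, ct_b))
```

Passing $p$ once and swapping only the exponent and ciphertext mirrors the observation that finding $r$ is the same procedure as finding $q$.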