Karatsuba Multiplication

Posted by Beetle B. on Wed 15 October 2014

In grade school we learn an algorithm for multiplying two \(N\)-digit integers. Let our basic operation be the multiplication or addition of two single-digit numbers. Then the classic algorithm is \(\Theta(N^{2})\).

Divide and Conquer Approach

For simplicity, assume \(N\) is a power of 2. Let the first \(N/2\) digits of the first multiplicand be \(a\), and the remaining digits be \(b\). Likewise, define \(c\) and \(d\) for the 2nd number.

As a concrete example, if the first number was 1234 and the second 6789, then \(a=12,b=34,c=67,d=89\).

Noting that the product is \((10^{N/2}a+b)(10^{N/2}c+d)\), multiplying out, we get:

\begin{equation*} (10^{N/2}a+b)(10^{N/2}c+d)=10^{N}ac+10^{N/2}(ad+bc)+bd \end{equation*}

Multiplying a number by a power of 10 is trivial: it is just a shift. So we’ve now reduced one \(N\)-digit multiplication to 4 multiplications of \(N/2\)-digit numbers (\(ac\), \(ad\), \(bc\) and \(bd\)), plus a few additions. We can recurse on each of these until we reach the base case of single digits.
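The four-product recursion described above can be sketched on plain Python integers. This is only an illustration, not the list-based implementation further down: the name `split_multiply` is made up here, and it assumes both inputs are non-negative with at most `n` digits, where `n` is a power of 2.

```python
def split_multiply(x, y, n):
    """Multiply non-negative ints x and y, each at most n digits (n a power of 2)."""
    if n == 1:
        return x * y  # base case: product of two single digits
    half = n // 2
    shift = 10 ** half
    a, b = divmod(x, shift)  # x = 10^(n/2) * a + b
    c, d = divmod(y, shift)  # y = 10^(n/2) * c + d
    ac = split_multiply(a, c, half)
    ad = split_multiply(a, d, half)
    bc = split_multiply(b, c, half)
    bd = split_multiply(b, d, half)
    return 10 ** n * ac + shift * (ad + bc) + bd


print(split_multiply(1234, 6789, 4))  # 8377626
```

With the example numbers, the first level splits into \(a=12,b=34,c=67,d=89\) and makes four recursive calls on 2-digit numbers.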

Complexity

So how does this algorithm compare to the classical algorithm? Applying the master theorem with \(\alpha=4\) subproblems, each smaller by a factor of \(\beta=2\), and a combining cost of \(\Theta(N)\), the complexity is \(\Theta(N^{\log_{2}4})=\Theta(N^{2})\).
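For reference, the recurrence behind that claim can be written out. Here \(T(N)\) is a symbol introduced only for this note, standing for the number of basic operations needed for two \(N\)-digit numbers:

\begin{equation*} T(N)=4\,T(N/2)+\Theta(N),\qquad \log_{\beta}\alpha=\log_{2}4=2>1\;\Rightarrow\;T(N)=\Theta(N^{\log_{2}4})=\Theta(N^{2}) \end{equation*}

The exponent 2 is larger than the exponent 1 of the combining step, so the recursive multiplications dominate the total cost.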

That’s not really an improvement.

Karatsuba Multiplication

The Karatsuba algorithm is the same as the one in the previous section, but with one improvement. We know that \(ad+bc=(a+b)(c+d)-ac-bd\). So we can:

  1. Calculate \((a+b)(c+d)\) recursively.
  2. Calculate \(ac\) recursively.
  3. Calculate \(bd\) recursively.
  4. Subtract to get \(ad+bc\).

We require only 3 multiplications, at the price of a few extra additions and subtractions, which are still \(\Theta(N)\). So now \(\alpha=3\), and the complexity becomes \(\Theta(N^{\lg3})\approx\Theta(N^{1.585})\).
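The three-product trick can be sketched on plain Python integers as follows. This is a hedged illustration rather than the list-based code below: the name `karatsuba_int` is made up, and the base case simply falls back to `*` as soon as one operand is a single digit. Because \(a+b\) can have one more digit than \(a\) or \(b\), the split point is recomputed from the actual digit counts at every level.

```python
def karatsuba_int(x, y):
    """Multiply two non-negative ints using three recursive products per level."""
    if x < 10 or y < 10:
        return x * y  # base case: at least one operand is a single digit
    half = max(len(str(x)), len(str(y))) // 2
    shift = 10 ** half
    a, b = divmod(x, shift)  # x = shift * a + b
    c, d = divmod(y, shift)  # y = shift * c + d
    ac = karatsuba_int(a, c)
    bd = karatsuba_int(b, d)
    # (a+b)(c+d) - ac - bd equals ad + bc, so no fourth product is needed.
    cross = karatsuba_int(a + b, c + d) - ac - bd
    return shift * shift * ac + shift * cross + bd


print(karatsuba_int(1234, 6789))  # 8377626
```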

Beware that the constant factor may be significant: \(N\) may need to be quite big before Karatsuba beats the classical algorithm.

Implementation

Python

I’ve included code for both the classical and the Karatsuba multiplication.

def add_basic(a, b, carry):
    """
    Add single-digit numbers a and b plus any incoming carry. Return the
    resulting digit and the outgoing carry.
    """
    result = a + b + carry
    # Floor division keeps the carry an integer on Python 3 as well.
    return result % 10, result // 10

def add(a, b):
    """
    Add two numbers.

    a and b are lists of integers with the least significant digit in front.
    """
    result = []
    carry = 0

    # Always make d the larger list.
    if len(a) < len(b):
        c, d = a, b
    else:
        c, d = b, a
    for index, v in enumerate(d):
        try:
            value, carry = add_basic(c[index], v, carry)
        except IndexError:
            value, carry = add_basic(0, v, carry)
        result.append(value)
    if carry != 0:
        result.append(carry)
    return result   

def subtract_basic(a, b, carry):
    """
    Subtract single digit numbers b from a and take care of any "carry".
    """
    a = a - carry
    if b <= a:
        result = a - b
        carry = 0
    else:
        result = 10 + a - b
        carry = 1
    return result, carry

def subtract(a, b):
    """
    Subtract two numbers. Return a-b

    a and b are lists of integers with the least significant digit in front.
    """
    result = []
    carry = 0
    c, d = b, a
    for index, v in enumerate(d):
        try:
            value, carry = subtract_basic(v, c[index], carry)
        except IndexError:
            value, carry = subtract_basic(v, 0, carry)
        result.append(value)
    if carry != 0:
        # A leftover borrow means b was larger than a.
        raise ValueError("subtract(a, b) requires a >= b")
    return result

def multiply_basic(a, b, carry):
    """
    Multiply single digit numbers, incorporating any carry.
    """
    result = a * b + carry
    # Floor division keeps the carry an integer on Python 3 as well.
    return result % 10, result // 10

def classic_multiply(N1, N2):
    """
    Classic grade school algorithm to multiply two numbers.
    
    N1, N2 are strings.
    """
    # We'll store numbers as lists, with the first element being the least
    # significant digit.
    num1 = list(map(int, N1[::-1]))
    num2 = list(map(int, N2[::-1]))
    return "".join(map(str, classic_multiply_helper(num1, num2)[::-1]))

def classic_multiply_helper(N1, N2):
    """
    Classic grade school algorithm to multiply two numbers.
    
    N1, N2 are lists of integers. The least significant digit is the first
    element.
    """
    num1, num2 = N1, N2
    # Always make d the longer number.
    if len(num1) < len(num2):
        c, d = num1, num2
    else:
        c, d = num2, num1
    
    sums = []
    for ind, val1 in enumerate(c):
        carry = 0
        tmp_sum = [0] * ind # Padded 0's.
        for val2 in d:
            result, carry = multiply_basic(val1, val2, carry)
            tmp_sum.append(result)
        if carry != 0:
            tmp_sum.append(carry)
        sums.append(tmp_sum)
    # Add all the intermediate sums together using our custom add function.
    # A plain loop is used because reduce is not a builtin on Python 3.
    final = [0]
    for partial_sum in sums:
        final = add(final, partial_sum)
    return final

def karatsuba(N1, N2):
    """
    Karatsuba algorithm to multiply two numbers.
    
    N1, N2 are strings.
    """
    num1 = list(map(int, N1[::-1]))
    num2 = list(map(int, N2[::-1]))
    return "".join(map(str, karatsuba_helper(num1, num2)[::-1]))

def karatsuba_helper(N1, N2):
    """
    N1, N2 are lists of integers. The least significant digit is the first
    element.

    Split N1 into a and b, and N2 into c and d. 
    """
    l1 = len(N1)
    l2 = len(N2)
    if (l1 == 0) or (l2 == 0):
        return [0]
    if (l1 == 1) or (l2 == 1):
        return classic_multiply_helper(N1, N2)
    # Floor division: the split point is used as a slice index and must be an int.
    new_l1 = max(l1, l2) // 2
    a, b = N1[new_l1:], N1[:new_l1]
    c, d = N2[new_l1:], N2[:new_l1]

    a_plus_b = add(a, b)
    c_plus_d = add(c, d)
    ac = karatsuba_helper(a, c)
    bd = karatsuba_helper(b, d)
    # big is (a+b)(c+d)
    big = karatsuba_helper(a_plus_b, c_plus_d)
    ac_plus_bd = add(ac, bd)
    # middle is (a+b)(c+d) - ac - bd
    middle = subtract(big, ac_plus_bd)

    zeros_ac = 2 * new_l1
    zeros_middle = new_l1
    r1 = [0]*zeros_ac
    r1.extend(ac)
    r2 = [0]*zeros_middle
    r2.extend(middle)
    r3 = bd
    s1 = add(r1, r2)
    s2 = add(s1, r3)
    return s2
    

Performance

As you can see, the Karatsuba algorithm seems like it will never overtake the classical algorithm.

Either I’m doing something wrong, or we’ll need a huge \(N\) for the Karatsuba algorithm to overtake the classical one. Either way, my Python implementation is useless in practice: it’s not often I’ll have to multiply two 1000-digit numbers.
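One way such a comparison could be measured is sketched below. It is not the script behind the numbers in this post: the helper names `random_digits` and `time_multiply` are made up, and `builtin_multiply` is only a stand-in with the same string-in, string-out interface as `classic_multiply` and `karatsuba`, so that the block runs on its own.

```python
import random
import timeit


def random_digits(n, rng):
    """Return a random n-digit decimal string with a nonzero leading digit."""
    first = rng.choice("123456789")
    rest = "".join(rng.choice("0123456789") for _ in range(n - 1))
    return first + rest


def time_multiply(multiply, n_digits, repeats=5, seed=0):
    """Best-of-`repeats` time in seconds for one n_digits x n_digits product."""
    rng = random.Random(seed)
    x = random_digits(n_digits, rng)
    y = random_digits(n_digits, rng)
    return min(timeit.repeat(lambda: multiply(x, y), number=1, repeat=repeats))


def builtin_multiply(n1, n2):
    """Stand-in multiplier; swap in classic_multiply or karatsuba to compare them."""
    return str(int(n1) * int(n2))


if __name__ == "__main__":
    for n in (16, 64, 256, 1024):
        print(n, time_multiply(builtin_multiply, n))
```

Taking the minimum over several repeats reduces noise from other processes, and a fixed seed means every multiplier is timed on the same inputs.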

Strassen’s Algorithm

Similar to multiplying two numbers, we can reduce the complexity of multiplying two \(N\times N\) matrices. The classic algorithm takes \(\Theta(N^{3})\). But Strassen’s Algorithm does it faster. I didn’t bother writing the gory details, but the principle is the same: divide and conquer, where the 8 block products of the straightforward recursion are replaced by 7 products plus some extra additions and subtractions. That gives \(\Theta(N^{\lg7})\approx\Theta(N^{2.807})\).
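To make the 7-product idea concrete, here is a minimal sketch for the smallest case, a \(2\times2\) matrix of plain numbers, using the standard Strassen formulas. The function name `strassen_2x2` and the labels `m1` to `m7` are illustrative. In the full algorithm the entries would be \(N/2\times N/2\) blocks and each product a recursive call.

```python
def strassen_2x2(A, B):
    """Multiply two 2x2 matrices (nested lists) using 7 multiplications."""
    (a11, a12), (a21, a22) = A
    (b11, b12), (b21, b22) = B

    # The seven products; entries of A always stay on the left.
    m1 = (a11 + a22) * (b11 + b22)
    m2 = (a21 + a22) * b11
    m3 = a11 * (b12 - b22)
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a12) * b22
    m6 = (a21 - a11) * (b11 + b12)
    m7 = (a12 - a22) * (b21 + b22)

    # Recombine with additions and subtractions only.
    return [
        [m1 + m4 - m5 + m7, m3 + m5],
        [m2 + m4, m1 - m2 + m3 + m6],
    ]


print(strassen_2x2([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```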

Supposedly, with modern architecture, no implementation will beat the classical algorithm for any problem of practical size. I think the claim was that for the \(N\) needed for Strassen’s algorithm to be faster, modern computers simply don’t have enough RAM to handle the algorithm. This wasn’t the case some decades ago, of course.