Edit Jan 31: an important special case is when the sums form a nested structure; search for "Hasse diagram is a tree" below.

Here's a practically relevant variation on the matrix chain problem:

Find the optimal way to compute a sum over all weighted paths in a graph, where the weight of each path is the matrix product of its edge labels (i.e., a matrix chain).

For instance, take the sum $$Q=A_0 A_1 A_2 A_3 A_4 + A_0 A_1 A_3 A_4 + A_0 A_1 A_4$$

We can represent it as the following sum over weighted paths:

[Figure: $Q$ drawn as a sum over weighted paths]

Now we can count the number of matrix multiplications (in red) involved in computing this sum:

[Figure: the same graph with the matrix multiplications marked in red]

Here's a more efficient way to compute $Q$:

$$Q=A_0A_1(A_2+I)A_3A_4 + A_0A_1A_4$$

We can view it as the following sum over paths:

[Figure: path graph for the factored form of $Q$]

Some edges are labeled with $I$, corresponding to multiplication by the identity matrix.
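As a quick sanity check, the factored form can be compared numerically against the expanded sum. This is only a sketch: it assumes NumPy and uses random square matrices of an arbitrary size `n` so that every term is well defined.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3  # assumed common size; square matrices keep every term well defined
A0, A1, A2, A3, A4 = (rng.standard_normal((n, n)) for _ in range(5))
I = np.eye(n)

# Expanded sum over the three paths.
Q_expanded = A0 @ A1 @ A2 @ A3 @ A4 + A0 @ A1 @ A3 @ A4 + A0 @ A1 @ A4
# Factored form, with the identity edge standing in for the skipped A2.
Q_factored = A0 @ A1 @ (A2 + I) @ A3 @ A4 + A0 @ A1 @ A4

assert np.allclose(Q_expanded, Q_factored)
```

In practice the multiplication by $I$ would not be carried out; it only marks the edge that skips $A_2$.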

Given a list of matrix dimensions $d_0,\ldots,d_n$ corresponding to matrices $A_0,\ldots,A_{n-1}$ (so that $A_i$ has shape $d_i \times d_{i+1}$) and a list of paths, the task is to find a sequence of matrix multiplications and additions which produces $Q$ using the smallest total number of scalar multiplications.

Specifying Paths

Each term in the sum is specified as a pair of numbers $(i,j)$, indicating that matrices $A_{i+1},A_{i+2},\ldots,A_j$ are not present in the term. The empty tuple $()$ denotes the full chain.

For instance, for the problem above, the paths are $[(), (1,2), (1,3)]$. The second term is missing matrices $\{A_2\}$ and the third term is missing $\{A_2,A_3\}$. Viewing paths as connected subgraphs of the chain graph, they are partially ordered by the subgraph relation. Therefore, the set of paths forms a lattice.
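To make the path notation concrete, here is a small sketch that expands a path specification into the indices of the matrices that remain in the term. The helper name `term_indices` is hypothetical and not part of the original post.

```python
def term_indices(n, path):
    """Indices of the matrices present in a term over A_0 .. A_{n-1}.

    `path` is () for the full chain, or (i, j) meaning that
    A_{i+1}, ..., A_j are skipped.
    """
    if not path:
        return list(range(n))
    i, j = path
    return [k for k in range(n) if not (i < k <= j)]


paths = [(), (1, 2), (1, 3)]
terms = [term_indices(5, p) for p in paths]
print(terms)  # [[0, 1, 2, 3, 4], [0, 1, 3, 4], [0, 1, 4]]
```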

An important special case is when the Hasse diagram is a tree.

For example, consider this sum.

$$W=A_0 A_1 A_3 A_4 + A_0 A_1 A_2 A_4 + A_0 A_1 A_4 + A_0 A_1 A_2 A_3$$

And the corresponding Hasse diagram:

[Figure: Hasse diagram for $W$]

Coming back to the original example $Q$, we can order matrix products in the following way: $$Q=(A_0 A_1) A_4 + (A_0 A_1) (A_3 A_4) + ((A_0 A_1) A_2) (A_3 A_4)$$

Notice that some terms like $A_0 A_1$ are repeated, so we can reuse this computation by introducing temporary variables like $T_0=A_0 A_1$.

We can visualize a particular schedule below, and use this colab to find that it requires 7 scalar multiplications when $d_0=d_1=\ldots=d_n=1$.

[Figure: visualization of the schedule]

T0 = matmul(A0, A1)
T1 = matmul(T0, A4)
T2 = matmul(A3, A4)
T3 = matmul(T0, T2)
T4 = add(T1, T3)
T5 = matmul(T0, A2)
T6 = matmul(T5, A3)
T7 = matmul(T6, A4)
Q = add(T7, T4)
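The count of 7 can be reproduced without the colab by tracking only shapes. The sketch below is an assumption-laden stand-in, not the colab's code: `Node`, `matmul`, `add`, and `scalar_mults` are hypothetical helpers. A product of an $a \times b$ matrix with a $b \times c$ matrix is charged $abc$ scalar multiplications, and each shared temporary is counted once.

```python
class Node:
    """Hypothetical expression node: tracks only a shape and its operands."""

    def __init__(self, shape, op="leaf", args=()):
        self.shape, self.op, self.args = shape, op, args


def matmul(a, b):
    assert a.shape[1] == b.shape[0], "inner dimensions must agree"
    return Node((a.shape[0], b.shape[1]), "matmul", (a, b))


def add(a, b):
    assert a.shape == b.shape, "shapes must agree"
    return Node(a.shape, "add", (a, b))


def scalar_mults(root):
    """Total scalar multiplications, counting each shared node once."""
    seen, total, stack = set(), 0, [root]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        if node.op == "matmul":
            a, b = node.args
            total += a.shape[0] * a.shape[1] * b.shape[1]
        stack.extend(node.args)
    return total


# All dimensions equal to 1, as in the question.
A0, A1, A2, A3, A4 = (Node((1, 1)) for _ in range(5))
T0 = matmul(A0, A1)
T1 = matmul(T0, A4)
T2 = matmul(A3, A4)
T3 = matmul(T0, T2)
T4 = add(T1, T3)
T5 = matmul(T0, A2)
T6 = matmul(T5, A3)
T7 = matmul(T6, A4)
Q = add(T7, T4)
print(scalar_mults(Q))  # 7

# Reusing T2 = A3 A4 in the last term saves one multiplication.
Q_reuse = add(matmul(T5, T2), T4)
print(scalar_mults(Q_reuse))  # 6
```

The second count corresponds to the parenthesization $((A_0 A_1) A_2)(A_3 A_4)$ written above, which reuses $A_3 A_4$ instead of multiplying by $A_3$ and $A_4$ separately.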

Now the question is -- how hard is this problem to solve or approximate:

  1. when $d_0=d_1=\ldots=d_n=1$
  2. when $d_i$'s are arbitrary positive integers

Edit Feb 17: example of the problematic split mentioned in the comments:

[Figure: example of a problematic split]

Yaroslav Bulatov

1 Answer

One idea would be to generalize the O(N^3) DP for a single path without skips (the classic matrix chain problem) to your case:
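For reference, the single-chain DP being generalized can be sketched as follows. This is the textbook matrix chain recurrence, not code from this answer; the function name `matrix_chain_cost` is hypothetical, and it assumes $A_i$ has shape `dims[i]` by `dims[i + 1]`.

```python
from functools import lru_cache


def matrix_chain_cost(dims):
    """Minimum scalar multiplications to compute A_0 @ ... @ A_{n-1}.

    Assumes A_i has shape dims[i] x dims[i + 1], so len(dims) == n + 1.
    """
    n = len(dims) - 1

    @lru_cache(maxsize=None)
    def best(i, j):
        # Cheapest way to multiply the half-open range A_i .. A_{j-1}.
        if j - i == 1:
            return 0
        return min(
            best(i, k) + best(k, j) + dims[i] * dims[k] * dims[j]
            for k in range(i + 1, j)
        )

    return best(0, n)


print(matrix_chain_cost([10, 30, 5, 60]))  # 4500
print(matrix_chain_cost([1] * 6))          # 4
```

There are $O(N^2)$ ranges and each tries $O(N)$ split points, which gives the $O(N^3)$ bound. The second call shows that a single chain of five $1 \times 1$ matrices needs 4 multiplications.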

d[i][j] would be the cost of computing all the products of the matrices in the half-open range [A_i ... A_j), over all the paths whose missing matrices lie fully within that range. It can be computed by iterating over a location at which to split the range [i, j) and recursively computing the best cost for the left and the right parts. If a subtree lies entirely on one side of the split point, then I believe it just works. If the split point falls in the middle of some subtree (and there will be at most one such subtree), the DP needs to be called recursively for that subtree, and the result needs to be added to the cost.

Below is the code, without proper handling of subtrees. For your particular example it already improves from 7 to 4 multiplications.

...

def compute_schedule(tensors, paths, cache, l=0, r=-1):
    # Returns (expression, expression string) for the half-open range [l, r).
    # Each path (a, b) marks the matrices A_a .. A_{b-1} as missing.
    if r == -1:
        r = len(tensors)

    if r == l + 1:
        return tensors[l], "A%s" % l

    if (l, r) in cache:
        return cache[(l, r)]

    ret = None
    for mid in range(l + 1, r):
        lm, ls = compute_schedule(tensors, paths, cache, l, mid)
        rm, rs = compute_schedule(tensors, paths, cache, mid, r)
        cur = [lm @ rm, "%s @ %s" % (ls, rs)]

        # If the Hasse diagram is a tree, `mid` will lie within at most one
        # child. The correct thing to do here would be to call the DP for
        # that subtree.
        # I don't do it here, and instead just compute the naive product from
        # left to right for all the runs in such a subtree. This is not
        # optimal in the general case.
        need_parens = False
        for path in paths:
            if l < path[0] <= mid < path[1] <= r:
                t = tensors[l]
                ts = "A%s" % l
                for k in range(l + 1, r):
                    if k < path[0] or k >= path[1]:
                        # Don't actually need to multiply tensors here;
                        # it is enough to compute the shape.
                        t @= tensors[k]
                        ts += " @ A%s" % k
                cur[0] += t
                cur[1] += " + %s" % ts
                need_parens = True

        if need_parens:
            cur[1] = "(" + cur[1] + ")"

        cur = tuple(cur)
        if ret is None or cost(cur[0]) < cost(ret[0]):
            ret = cur

    cache[(l, r)] = ret
    return ret


# np is numpy; LeafNode and cost are presumably defined in the elided code above.
d = 1
ones = np.ones((d, d))
(A0, A1, A2, A3, A4) = (LeafNode(k * ones) for k in (1, 2, 3, 4, 5))
Q = A0 @ A1 @ A2 @ A3 @ A4 + A0 @ A1 @ A3 @ A4 + A0 @ A1 @ A4

# Paths are half-open ranges of missing matrices:
# (2, 3) drops A2, and (2, 4) drops A2 and A3.
Q2, Q2_str = compute_schedule([A0, A1, A2, A3, A4], [(2, 3), (2, 4)], {})
print(Q2_str)

np.testing.assert_almost_equal(Q.value, Q2.value)
assert cost(Q) == 9
assert cost(Q2) == 4
print(f"Schedule2 requires {cost(Q2)} scalar multiplications")

Yields

A0 @ ((A1 @ A2 + A1) @ A3 + A1) @ A4
Schedule2 requires 4 scalar multiplications
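
The printed expression can also be checked with plain NumPy arrays and non-unit dimensions. This is a sketch under stated assumptions: it takes $A_i$ to have shape $d_i \times d_{i+1}$, in which case the terms that skip matrices only type-check when $d_2=d_3=d_4$, so the dimensions below are chosen accordingly.

```python
import numpy as np

rng = np.random.default_rng(0)
# Assumed shapes: A_i is dims[i] x dims[i + 1]. The terms that skip A2 (and A3)
# require d_2 = d_3 = d_4, so those three dimensions are set equal.
dims = [2, 3, 4, 4, 4, 5]
A0, A1, A2, A3, A4 = (
    rng.standard_normal((dims[i], dims[i + 1])) for i in range(5)
)

# Expanded sum from the question.
Q = A0 @ A1 @ A2 @ A3 @ A4 + A0 @ A1 @ A3 @ A4 + A0 @ A1 @ A4
# Nested expression printed by compute_schedule.
Q2 = A0 @ ((A1 @ A2 + A1) @ A3 + A1) @ A4

assert Q.shape == (2, 5)
assert np.allclose(Q, Q2)
```

The nested form uses four matrix products regardless of the dimensions, against nine for the expanded sum evaluated term by term.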
Ishamael