
Finding all partitions of a set in Java

Posted 2020-02-26 14:04

Question:

I have the following Python function to recursively find all partitions of a set:

def partitions(set_):
    if not set_:
        yield []
        return
    for i in xrange(2**len(set_)/2):
        parts = [set(), set()]
        for item in set_:
            parts[i&1].add(item)
            i >>= 1
        for b in partitions(parts[1]):
            yield [parts[0]]+b

for p in partitions(["a", "b", "c", "d"]):
    print(p)

Can someone help me to translate this into Java? This is what I have so far:

private static List<List<List<String>>> partitions(List<String> inputSet) {
    List<List<List<String>>> res = Lists.newArrayList();
    if (inputSet.size() == 0) {
        List<List<String>> empty = Lists.newArrayList();
        res.add(empty);
        return res;
    }
    int limit = (int)(Math.pow(2, inputSet.size())/2);
    for (int i = 0; i<limit; i++) {
        List<List<String>> parts = Lists.newArrayList();
        List<String> part1 = Lists.newArrayList();
        List<String> part2 = Lists.newArrayList();
        parts.add(part1);
        parts.add(part2);
        for (String item: inputSet) {
            parts.get(i&1).add(item);
            i >>= 1;
        }
        for (List<List<String>> b: partitions(parts.get(1))) {
            List<List<String>> set = Lists.newArrayList();
            set.add(parts.get(0));
            set.addAll(b);
            res.add(set);
        }
    }
    return res;
}

I get an infinite recursion when executing it with more than one element.

A post similar to this one (with Ruby) can be found here. The original Python code can be found here and here.

Answer 1:

You're very close to the right answer. What you describe as infinite recursion is actually an infinite loop: the outermost for loop never terminates.

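The primary difference from the Python code is that in Python, xrange hands the loop a fresh value of i on every iteration, so modifying i inside the body does not affect the iteration. In your Java version, i is the loop counter itself, so the i >>= 1 statement in the inner loop shifts it back to zero on every pass. The i++ then sets it to 1, which is always below the limit once the input has more than one element, so the loop never ends. The easy fix is to use separate variables for the inner and outer loops.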
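To see the effect in isolation, here is a minimal sketch that strips away the lists and keeps only the counter arithmetic. The class and method names (LoopCounterDemo, shiftOncePerItem, counterAfterSteps) are made up for illustration and are not part of the question's code.

```java
class LoopCounterDemo {
    // Mimics the inner loop: shifts the value right once per input item.
    static int shiftOncePerItem(int value, int itemCount) {
        for (int k = 0; k < itemCount; k++) {
            value >>= 1;
        }
        return value;
    }

    // Simulates the buggy outer loop for a bounded number of passes and
    // returns the counter value seen by the loop test afterwards.
    static int counterAfterSteps(int itemCount, int steps) {
        int i = 0;
        for (int s = 0; s < steps; s++) {
            i = shiftOncePerItem(i, itemCount); // loop body clobbers the counter
            i++;                                // the for loop's own increment
        }
        return i;
    }

    public static void main(String[] args) {
        int itemCount = 4;
        int limit = 1 << (itemCount - 1); // 8 for four items
        // Every counter value below the limit is shifted down to zero.
        for (int j = 0; j < limit; j++) {
            System.out.println(j + " -> " + shiftOncePerItem(j, itemCount));
        }
        // After any number of passes the counter is 1, never reaching the limit.
        System.out.println(counterAfterSteps(itemCount, 1000)); // prints 1
    }
}
```

Because the counter is 1 at every loop test, the condition i < limit stays true whenever limit is at least 2, which matches the observation that the problem appears with more than one element.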

In general, this is why it's a bad idea to translate a program directly from one language to another. Almost every program uses some idioms that make sense in the original language but are bizarre or meaningless in the target language. In particular, the Python code relies on implicit promotion to arbitrary precision integers for its correctness. Java's int does not do this, so the corrected version in this answer suffers from integer overflow if the input set is larger than 31 elements. Your example has only 4 elements, so for this specific case it produces the right answer.
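The overflow can be checked on its own. The sketch below is only an illustration: the names LimitOverflowDemo, intLimit, powLimit and bigLimit are invented here, and the BigInteger variant is shown as one possible way to mirror Python's behaviour, not as something the code in this answer uses.

```java
import java.math.BigInteger;

class LimitOverflowDemo {
    // Loop limit computed with an int shift, 2^(size-1).
    static int intLimit(int size) {
        return 1 << (size - 1);
    }

    // Loop limit computed the way the question does it, via Math.pow and a cast.
    static int powLimit(int size) {
        return (int) (Math.pow(2, size) / 2);
    }

    // Arbitrary-precision limit, comparable to what Python computes.
    static BigInteger bigLimit(int size) {
        return BigInteger.ONE.shiftLeft(size - 1);
    }

    public static void main(String[] args) {
        System.out.println(intLimit(4));   // 8
        System.out.println(intLimit(31));  // 1073741824, still fits in an int
        System.out.println(intLimit(32));  // -2147483648, the sign bit is set
        System.out.println(intLimit(33));  // 1, the shift distance wraps to 0
        System.out.println(powLimit(32));  // 2147483647, the cast clamps to Integer.MAX_VALUE
        System.out.println(bigLimit(32));  // 2147483648, the mathematically correct value
    }
}
```

With 32 elements the int shift produces a negative limit, so the outer loop would not run at all. The Math.pow variant does not wrap but clamps, which is still not the correct value.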

Here's a corrected Java version:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Partition {
    private static List<List<List<String>>> partitions(List<String> inputSet) {
        List<List<List<String>>> res = new ArrayList<>();
        if (inputSet.isEmpty()) {
            List<List<String>> empty = new ArrayList<>();
            res.add(empty);
            return res;
        }
        // Note that this algorithm only works if inputSet.size() <= 31,
        // since 1 << (size - 1) overflows int beyond that. Using
        // Math.pow and casting back to int does not help either. The
        // original Python code does not have this limitation because
        // Python will implicitly promote to a long, which in Python
        // terms is an arbitrary precision integer similar to Java's
        // BigInteger.
        int limit = 1 << (inputSet.size() - 1);
        // Note the separate variable to avoid resetting
        // the loop variable on each iteration.
        for (int j = 0; j < limit; ++j) {
            List<List<String>> parts = new ArrayList<>();
            List<String> part1 = new ArrayList<>();
            List<String> part2 = new ArrayList<>();
            parts.add(part1);
            parts.add(part2);
            int i = j;
            for (String item : inputSet) {
                parts.get(i&1).add(item);
                i >>= 1;
            }
            for (List<List<String>> b : partitions(part2)) {
                List<List<String>> holder = new ArrayList<>();
                holder.add(part1);
                holder.addAll(b);
                res.add(holder);
            }
        }
        return res;
    }

    public static void main(String[] args) {
        for (List<List<String>> partitions :
                 partitions(Arrays.asList("a", "b", "c", "d"))) {
            System.out.println(partitions);
        }
    }
}

Here's the output of my Java version:

[[a, b, c, d]]
[[b, c, d], [a]]
[[a, c, d], [b]]
[[c, d], [a, b]]
[[c, d], [b], [a]]
[[a, b, d], [c]]
[[b, d], [a, c]]
[[b, d], [c], [a]]
[[a, d], [b, c]]
[[a, d], [c], [b]]
[[d], [a, b, c]]
[[d], [b, c], [a]]
[[d], [a, c], [b]]
[[d], [c], [a, b]]
[[d], [c], [b], [a]]
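As a sanity check on the output above: the number of partitions of a set with n elements is the Bell number B(n), and B(4) = 15, which matches the 15 lines printed. The sketch below computes Bell numbers with the Bell triangle. The class name BellCheck and the method bell are hypothetical helpers added for this check, and the use of long assumes a small n such as the one in this example.

```java
class BellCheck {
    // Bell number B(n) computed with the Bell triangle.
    // Assumes a small n; long is more than enough for the sizes used here.
    static long bell(int n) {
        long[] row = {1}; // row 0 of the triangle, whose first entry is B(0)
        for (int r = 1; r <= n; r++) {
            long[] next = new long[r + 1];
            // Each row starts with the last entry of the previous row.
            next[0] = row[r - 1];
            // Each further entry is its left neighbour plus the entry above that neighbour.
            for (int c = 1; c <= r; c++) {
                next[c] = next[c - 1] + row[c - 1];
            }
            row = next;
        }
        return row[0]; // first entry of row n is B(n)
    }

    public static void main(String[] args) {
        System.out.println(bell(4)); // prints 15
    }
}
```

Comparing res.size() against such a count is a cheap way to confirm that a translation neither drops nor duplicates partitions.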