January 25, 2021

Java Joy: Partition Stream By Predicate

The Java Stream API has many useful methods. If we want to partition a stream of objects by a given predicate we can use the partitioningBy() method from the java.util.stream.Collectors class. We pass this method as the argument of the stream's collect() method. The result is a Map with the keys true and false. Objects from the stream for which the predicate returns true end up in the list for the true key, and objects for which the predicate returns false end up in the list for the false key. The partitioningBy() method also accepts a collector as second argument. This downstream collector is applied to the objects of each partition, and its result becomes the value for the true or false key instead of a list.
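A minimal sketch of both forms, assuming a hypothetical list of language names and a hypothetical class PartitionSketch with helper methods byShortName and countByShortName (none of these are part of the original post). The first method uses only a predicate, the second adds Collectors.counting() as downstream collector, so the values become Long counts instead of lists:

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical example class, not from the original post.
public class PartitionSketch {

    // Single-argument form: each word ends up in the list for the key
    // true (length <= 4) or false (longer words).
    static Map<Boolean, List<String>> byShortName(List<String> words) {
        return words.stream()
                    .collect(Collectors.partitioningBy(w -> w.length() <= 4));
    }

    // Two-argument form: the downstream collector counting() is applied
    // to the words of each partition, so the values are counts.
    static Map<Boolean, Long> countByShortName(List<String> words) {
        return words.stream()
                    .collect(Collectors.partitioningBy(w -> w.length() <= 4,
                                                       Collectors.counting()));
    }

    public static void main(String[] args) {
        var words = List.of("Java", "Groovy", "Kotlin", "Go", "Scala");
        System.out.println(byShortName(words).get(true));       // [Java, Go]
        System.out.println(countByShortName(words).get(false)); // 3
    }
}
```

Any other collector, for example Collectors.mapping() or Collectors.toSet(), can be used as downstream collector in the same way.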

In the following example we use the partitioningBy() method with different streams (run it with java -ea, otherwise the assert statements are skipped):

package mrhaki.stream;

import java.util.stream.Stream;
import java.util.stream.Collectors;
import java.util.List;
import java.util.Map;

public class Main {
    public static void main(String[] args) {
        // Let's create an infinite stream of integers.
        var range = Stream.iterate(0, i -> i + 1);

        // We can partition them by odd and even numbers
        // using a predicate i % 2 == 0, where i is an integer from the stream.
        var oddEvenNumbers =
            range.limit(10)
                 .collect(Collectors.partitioningBy(i -> i % 2 == 0));

        // Even numbers are assigned to the "true" key,
        // odd numbers to the "false" key.
        var odds = oddEvenNumbers.get(false);
        var evens = oddEvenNumbers.get(true);

        assert odds.size() == 5;
        assert odds.equals(List.of(1, 3, 5, 7, 9));
        assert evens.size() == 5;
        assert evens.equals(List.of(0, 2, 4, 6, 8));

        // We pass a downstream collector as second argument
        // to sum the even numbers and the odd numbers.
        var sumEvenOdd =
            Stream.iterate(0, i -> i + 1)
                  .limit(100)
                  .collect(
                      Collectors.partitioningBy(i -> i % 2 == 0,
                                                Collectors.reducing(0, (result, i) -> result + i)));

        assert sumEvenOdd.get(true) == 2450;
        assert sumEvenOdd.get(false) == 2500;
        

        // In the next example we start with an immutable map.
        var map = Map.of("language", "Java", "username", "mrhaki", "age", 47);

        // This time we partition on the type of the value where values
        // of type String are assigned to "true" and other types to "false".
        var partitionByStringValue = 
            map.entrySet()
               .stream()
               .collect(Collectors.partitioningBy(entry -> entry.getValue() instanceof String));

        var stringValueEntries = partitionByStringValue.get(true);
        var nonStringValueEntries = partitionByStringValue.get(false);

        assert stringValueEntries.size() == 2;

        var keys = stringValueEntries.stream().map(Map.Entry::getKey).collect(Collectors.toUnmodifiableList());
        var values = stringValueEntries.stream().map(Map.Entry::getValue).collect(Collectors.toUnmodifiableList());
        assert keys.containsAll(List.of("language", "username"));
        assert values.containsAll(List.of("Java", "mrhaki"));

        assert nonStringValueEntries.size() == 1;
        assert nonStringValueEntries.get(0).getKey().equals("age");
        assert nonStringValueEntries.get(0).getValue().equals(47);       
    }
}
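One detail the example relies on without stating it: the map returned by partitioningBy() always contains both the true and the false key, even when a partition is empty, while groupingBy() with a boolean classifier only contains the keys that actually occur in the stream. A small sketch, using a hypothetical class EmptyPartitionSketch with hypothetical helper methods partitionNegative and groupNegative:

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical example class, not from the original post.
public class EmptyPartitionSketch {

    // partitioningBy always returns both keys; a partition without
    // elements gets an empty list from the default downstream collector.
    static Map<Boolean, List<Integer>> partitionNegative(List<Integer> numbers) {
        return numbers.stream()
                      .collect(Collectors.partitioningBy(n -> n < 0));
    }

    // groupingBy with a boolean classifier only contains the keys
    // that occur in the stream.
    static Map<Boolean, List<Integer>> groupNegative(List<Integer> numbers) {
        return numbers.stream()
                      .collect(Collectors.groupingBy(n -> n < 0));
    }

    public static void main(String[] args) {
        var positives = List.of(1, 2, 3);
        System.out.println(partitionNegative(positives).get(true)); // []
        System.out.println(groupNegative(positives).get(true));     // null
    }
}
```

Because of this guarantee, get(true) and get(false) on a partitioningBy() result do not return null, so calling size() on both lists, as the example above does, is safe.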

Written with Java 15.0.1.