Skip to content

Stream API

Stream概述

Stream 是 Java 8 引入的函数式数据处理 API，支持链式操作。一条 Stream 流水线由数据源、零个或多个中间操作以及一个终端操作组成：中间操作是惰性的，只有调用终端操作时才会真正执行；每个 Stream 只能被消费一次。

创建Stream

java
import java.util.stream.*;
import java.util.*;

/**
 * Demonstrates the common ways of obtaining a {@link Stream}.
 *
 * <p>None of the streams below is consumed. Streams are lazy, so even the
 * unbounded ones do not produce any element until a terminal operation runs.
 */
public class CreateStream {
    public static void main(String[] args) {
        // ---- Finite sources ----
        // Varargs factory method
        Stream<String> ofValues = Stream.of("a", "b", "c");

        // Backed by an array
        String[] letterArray = {"a", "b", "c"};
        Stream<String> fromArray = Arrays.stream(letterArray);

        // Backed by a collection, in sequential or parallel mode
        List<String> letterList = Arrays.asList("a", "b", "c");
        Stream<String> sequential = letterList.stream();
        Stream<String> parallel = letterList.parallelStream();

        // A stream with no elements at all
        Stream<String> nothing = Stream.empty();

        // ---- Unbounded sources (nothing is generated until consumed) ----
        Stream<Double> randomDoubles = Stream.generate(Math::random);
        Stream<Integer> counting = Stream.iterate(0, i -> i + 1);

        // Java 9+: iterate with a hasNext predicate, yielding 0 .. 99
        Stream<Integer> below100 = Stream.iterate(0, i -> i < 100, i -> i + 1);
    }
}

中间操作

filter - 过滤

java
// Keep only the even numbers: filter retains elements whose predicate is true.
List<Integer> source = Arrays.asList(1, 2, 3, 4, 5, 6);
Stream<Integer> evenStream = source.stream().filter(x -> x % 2 == 0);
List<Integer> evens = evenStream.collect(Collectors.toList());
// [2, 4, 6]

map - 映射

java
List<String> names = Arrays.asList("张三", "李四", "王五");

// Map each name to its length (String::length counts UTF-16 code units)
List<Integer> lengths = names.stream()
    .map(String::length)
    .collect(Collectors.toList());
// [2, 2, 2]

// Upper-case each name. Locale.ROOT makes the result independent of the JVM's
// default locale (the no-arg toUpperCase() is locale-sensitive, e.g. Turkish "i").
// Chinese characters have no case, so these particular names come back unchanged.
List<String> upperNames = names.stream()
    .map(name -> name.toUpperCase(Locale.ROOT))
    .collect(Collectors.toList());
// [张三, 李四, 王五]

flatMap - 扁平化

java
// A list of lists that we want as one flat list
List<List<Integer>> nestedLists = Arrays.asList(
    Arrays.asList(1, 2),
    Arrays.asList(3, 4),
    Arrays.asList(5, 6)
);

// flatMap replaces every inner list by the elements it contains
Stream<Integer> allElements = nestedLists.stream().flatMap(List::stream);
List<Integer> flattened = allElements.collect(Collectors.toList());
// [1, 2, 3, 4, 5, 6]

// Split each word into one-character strings, then drop repeats.
// distinct keeps the first occurrence, so encounter order is preserved.
List<String> words = Arrays.asList("Hello", "World");
List<String> chars = words.stream()
    .map(word -> word.split(""))
    .flatMap(Arrays::stream)
    .distinct()
    .collect(Collectors.toList());
// [H, e, l, o, W, r, d]

distinct - 去重

java
List<Integer> numbers = Arrays.asList(1, 2, 2, 3, 3, 3);
// distinct() drops duplicates (as decided by equals) and keeps first occurrences in order
Stream<Integer> unique = numbers.stream().distinct();
List<Integer> distinct = unique.collect(Collectors.toList());
// [1, 2, 3]

sorted - 排序

java
List<Integer> numbers = Arrays.asList(3, 1, 4, 1, 5, 9);

// Natural (ascending) order
List<Integer> sorted = numbers.stream()
    .sorted(Comparator.naturalOrder())
    .collect(Collectors.toList());
// [1, 1, 3, 4, 5, 9]

// Descending order via a reversing comparator
Comparator<Integer> descending = Collections.reverseOrder();
List<Integer> reverseSorted = numbers.stream()
    .sorted(descending)
    .collect(Collectors.toList());
// [9, 5, 4, 3, 1, 1]

// Sort objects by one of their properties (ascending age)
List<Person> people = getPeople();
Comparator<Person> byAge = Comparator.comparing(Person::getAge);
List<Person> sortedPeople = people.stream()
    .sorted(byAge)
    .collect(Collectors.toList());

limit / skip - 限制/跳过

java
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

// limit(n): keep at most the first n elements
List<Integer> first5 = numbers.stream().limit(5).collect(Collectors.toList());
// [1, 2, 3, 4, 5]

// skip(n): discard the first n elements
List<Integer> skip5 = numbers.stream().skip(5).collect(Collectors.toList());
// [6, 7, 8, 9, 10]

// Pagination (1-based pages): skip the earlier pages, then take one page
int page = 2;
int pageSize = 3;
int offset = (page - 1) * pageSize;
List<Integer> page2 = numbers.stream()
    .skip(offset)
    .limit(pageSize)
    .collect(Collectors.toList());
// [4, 5, 6]

peek - 查看元素（主要用于调试，不应依赖它产生业务副作用）

java
// peek runs an action on each element as it passes, without changing it.
// Because this sequential stream is lazy, each element travels the whole
// pipeline before the next one starts, so the two kinds of log lines interleave.
Stream<Integer> traced = Stream.of(1, 2, 3, 4, 5)
    .peek(value -> System.out.println("处理: " + value));
Stream<Integer> evensOnly = traced
    .filter(value -> value % 2 == 0)
    .peek(value -> System.out.println("过滤后: " + value));
List<Integer> result = evensOnly.collect(Collectors.toList());
// result: [2, 4]

终端操作

collect - 收集

java
List<String> names = Arrays.asList("张三", "李四", "王五", "赵六");

// Into a List
List<String> list = names.stream().collect(Collectors.toList());

// Into a Set (no ordering guarantee)
Set<String> set = names.stream().collect(Collectors.toSet());

// Into a Map of name -> length. toMap throws IllegalStateException on a
// duplicate key unless a merge function is supplied; these names are unique.
Map<String, Integer> map = names.stream()
    .collect(Collectors.toMap(n -> n, n -> n.length()));

// Into a caller-chosen collection type
LinkedList<String> linkedList = names.stream()
    .collect(Collectors.toCollection(LinkedList::new));

// Concatenate into one String with a separator
String joined = names.stream().collect(Collectors.joining(", "));
// 张三, 李四, 王五, 赵六

统计操作

java
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);

long count = numbers.stream().count();                                    // 5
Optional<Integer> max = numbers.stream().max(Comparator.naturalOrder());  // Optional[5]
Optional<Integer> min = numbers.stream().min(Comparator.naturalOrder());  // Optional[1]

// Switch to a primitive IntStream for arithmetic (no boxing)
int sum = numbers.stream().mapToInt(i -> i).sum();                        // 15
OptionalDouble avg = numbers.stream().mapToInt(i -> i).average();         // OptionalDouble[3.0]

// Count, sum, average, max and min gathered in a single pass
IntSummaryStatistics stats = numbers.stream()
    .mapToInt(Integer::intValue)
    .summaryStatistics();
System.out.println(stats.getCount());   // 5
System.out.println(stats.getSum());     // 15
System.out.println(stats.getAverage()); // 3.0
System.out.println(stats.getMax());     // 5
System.out.println(stats.getMin());     // 1

reduce - 归约

java
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);

// Sum with an identity value: the result is a plain int (0 for an empty stream)
int sum = numbers.stream()
    .reduce(0, Integer::sum);  // 15

// Without an identity the result is an Optional (empty for an empty stream)
Optional<Integer> sum2 = numbers.stream()
    .reduce(Integer::sum);  // Optional[15]

// Maximum
Optional<Integer> max = numbers.stream()
    .reduce(Integer::max);  // Optional[5]

// String concatenation. The identity passed to reduce must be neutral for the
// accumulator; reduce("", (a, b) -> a + " " + b) would produce " Hello World"
// (leading space), because "" + " " + "Hello" is " Hello". Reducing without an
// identity avoids that. In real code, prefer Collectors.joining(" ").
List<String> words = Arrays.asList("Hello", "World");
String sentence = words.stream()
    .reduce((a, b) -> a + " " + b)
    .orElse("");  // "Hello World"

forEach - 遍历

java
List<String> names = Arrays.asList("张三", "李四", "王五");

// Terminal forEach on a stream
names.stream().forEach(name -> System.out.println(name));

// Same output through Iterable.forEach, with no stream involved
names.forEach(name -> System.out.println(name));

匹配操作

java
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);

// allMatch: true if every element satisfies the predicate
boolean allPositive = numbers.stream().allMatch(x -> x > 0);       // true

// anyMatch: true if at least one element satisfies it (stops at the first hit)
boolean anyEven = numbers.stream().anyMatch(x -> x % 2 == 0);      // true

// noneMatch: true if no element satisfies it
boolean noneNegative = numbers.stream().noneMatch(x -> x < 0);     // true

查找操作

java
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);

// findFirst: the first matching element in encounter order
Optional<Integer> first = numbers.stream()
    .filter(value -> value > 2)
    .findFirst();  // Optional[3]

// findAny: any matching element. It is free to ignore encounter order,
// which can make it cheaper than findFirst on parallel streams.
Optional<Integer> any = numbers.stream()
    .filter(value -> value > 2)
    .findAny();

分组和分区

分组

java
List<Person> people = Arrays.asList(
    new Person("张三", 25, "北京"),
    new Person("李四", 30, "上海"),
    new Person("王五", 25, "北京"),
    new Person("赵六", 30, "上海")
);

// city -> people living there (toList is groupingBy's default downstream)
Map<String, List<Person>> byCity = people.stream()
    .collect(Collectors.groupingBy(Person::getCity, Collectors.toList()));

// age -> people of that age
Map<Integer, List<Person>> byAge = people.stream()
    .collect(Collectors.groupingBy(Person::getAge, Collectors.toList()));

// Two levels: city -> (age -> people)
Map<String, Map<Integer, List<Person>>> byCityAndAge = people.stream()
    .collect(Collectors.groupingBy(Person::getCity,
        Collectors.groupingBy(Person::getAge)));

// A counting downstream collector: city -> head count
Map<String, Long> countByCity = people.stream()
    .collect(Collectors.groupingBy(Person::getCity, Collectors.counting()));

// An averaging downstream collector: city -> mean age
Map<String, Double> avgAgeByCity = people.stream()
    .collect(Collectors.groupingBy(Person::getCity,
        Collectors.averagingInt(Person::getAge)));

分区

java
// Split into exactly two groups by a predicate. The resulting map always
// contains both the true and the false key, even when one side is empty.
Map<Boolean, List<Person>> partitioned = people.stream()
    .collect(Collectors.partitioningBy(person -> person.getAge() >= 30));

// With a downstream collector: count per side
Map<Boolean, Long> countByAge = people.stream()
    .collect(Collectors.partitioningBy(person -> person.getAge() >= 30,
        Collectors.counting()));

并行流

java
// Boxed Integers 1 .. 999_999 (the upper bound of range is exclusive)
List<Integer> numbers = IntStream.range(1, 1000000).boxed().collect(Collectors.toList());

// NOTE: a single timed run includes JIT warm-up and GC noise, so these numbers
// are only indicative; use a benchmark harness such as JMH for real comparisons.
// System.nanoTime() is meant for measuring elapsed time; currentTimeMillis()
// follows the wall clock and can jump.

// Sequential stream.
// Widen to long BEFORE multiplying: n * n is evaluated in int arithmetic and
// overflows once n exceeds 46340, which would silently corrupt the sum.
long start = System.nanoTime();
long sum1 = numbers.stream()
    .mapToLong(n -> (long) n * n)
    .sum();
System.out.println("串行: " + (System.nanoTime() - start) / 1_000_000 + "ms");

// Parallel stream: same computation, split across worker threads
start = System.nanoTime();
long sum2 = numbers.parallelStream()
    .mapToLong(n -> (long) n * n)
    .sum();
System.out.println("并行: " + (System.nanoTime() - start) / 1_000_000 + "ms");

// An existing sequential stream can be switched to parallel mode
Stream<Integer> parallel = numbers.stream().parallel();

实战示例

数据处理

java
// Filter, de-duplicate, sort by age (oldest first), then take one page.
// distinct() runs before sorted() so duplicates are removed before sorting.
// distinct() relies on Person.equals/hashCode. NOTE(review): confirm that Person
// overrides them; otherwise only identical references are treated as duplicates.
int pageIndex = 0;  // zero-based page number
int pageSize = 10;
List<Person> result = people.stream()
    .filter(p -> p.getAge() > 20)
    .distinct()
    .sorted(Comparator.comparingInt(Person::getAge).reversed())
    .skip((long) pageIndex * pageSize)  // long arithmetic: no int overflow on large pages
    .limit(pageSize)
    .collect(Collectors.toList());

嵌套处理

java
// The distinct names of all products across every order
List<Order> orders = getOrders();
List<String> productNames = orders.stream()
    .map(Order::getProducts)
    .flatMap(products -> products.stream())
    .map(Product::getName)
    .distinct()
    .collect(Collectors.toList());

统计分析

java
// Per department: count, sum, min, average and max of salaries, in one pass.
// The downstream collector is built first, then passed to groupingBy.
Collector<Employee, ?, DoubleSummaryStatistics> salarySummary =
    Collectors.summarizingDouble(Employee::getSalary);
Map<String, DoubleSummaryStatistics> salaryStats = employees.stream()
    .collect(Collectors.groupingBy(Employee::getDepartment, salarySummary));