Skip to main content

寻找两个正序数组的中位数

Tips

题目类型: 二分查找

题目

给定两个大小分别为 mn 的正序数组 nums1nums2. 请你找出并返回这两个正序数组的中位数. 要求使用 O(log(m + n)) 的时间复杂度.

提示:
  • nums1.length == m
  • nums2.length == n
  • 0 <= m <= 1000
  • 0 <= n <= 1000
  • 1 <= m + n <= 2000
  • -10⁶ <= nums1[i], nums2[i] <= 10⁶
示例
输入: nums1 = [1, 3], nums2 = [2]
输出: 2
解释: 合并后的数组 = [1, 2, 3], 中位数 2
输入: nums1 = [1, 2], nums2 = [3, 4]
输出: 2.50000
解释: 合并数组 = [1, 2, 3, 4], 中位数 (2 + 3) / 2 = 2.5

题解

由于题目要求 O(log(m + n)) 的时间复杂度, 那就要往二分查找上想. 题目要求求中位数, 其实就是求第 k 小的数.

核心思路: 每次排除掉 k/2 个元素.

假设我们要找第 k 小的数, 我们可以比较两个数组中第 k/2 个元素(即下标为 k/2 - 1):

  • 如果 nums1[k/2 - 1] < nums2[k/2 - 1], 那么 nums1 的前 k/2 个元素一定都在前 k 小的元素范围内, 且不包含第 k 小的那个数.
    • 为什么? 假设 nums1 的第 k/2 个元素比 nums2 的第 k/2 个元素小, 即使 nums2 的前 k/2 - 1 个元素都比 nums1 的第 k/2 个元素小, 这两部分加起来也只有 k/2 + k/2 - 1 = k - 1 个元素. 所以 nums1 的第 k/2 个元素最多只能是第 k-1 小的数.
    • 因此, 我们可以放心地把 nums1 的前 k/2 个元素"逻辑上"移除(通过移动起始索引).
  • 反之, 如果 nums1[k/2 - 1] > nums2[k/2 - 1], 则可以排除 nums2 的前 k/2 个元素.

举例说明:

假设 nums1 = [1, 3, 4, 9], nums2 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 我们要找第 7 小的数字 (k=7).

  1. 第一轮: k=7, k/2=3. 比较 nums1[2] (即 4) 和 nums2[2] (即 3).

    • 因为 3 < 4, 所以 nums2 的前 3 个元素 [1, 2, 3] 肯定不是第 7 小的数, 可以排除.
    • nums2 逻辑上变为 [4, 5, 6, 7, 8, 9, 10].
    • 我们排除了 3 个元素, 所以现在要找第 7 - 3 = 4 小的数.
  2. 第二轮: k=4, k/2=2. 比较 nums1[1] (即 3) 和 nums2 新起点的第 2 个元素 (即 5).

    • 注意: nums2 的新起点是原数组下标 3 (即值 4). 所以比较的是 nums1[1] (3) 和 nums2[3+2-1] (5).
    • 因为 3 < 5, 所以 nums1 的前 2 个元素 [1, 3] 可以排除.
    • nums1 逻辑上变为 [4, 9].
    • 我们又排除了 2 个元素, 所以现在要找第 4 - 2 = 2 小的数.
  3. 第三轮: k=2, k/2=1. 比较 nums1 新起点第 1 个元素 (4) 和 nums2 新起点第 1 个元素 (4).

    • 值相等, 我们可以任意排除一方. 假设排除 nums1 的.
    • nums1 逻辑上变为 [9].
    • k 变为 2 - 1 = 1.
  4. 终止: k=1. 此时只要比较两个数组当前起始位置的元素, 较小的那个就是答案.

    • nums1 当前是 [9], nums2 当前是 [4, 5, ...].
    • min(9, 4) = 4.
    • 所以第 7 小的数是 4.

在代码实现中, 我们使用 start1start2 指针来标记逻辑上的数组起始位置, 避免真正的数组切片操作.

/**
* @param {number[]} nums1
* @param {number[]} nums2
* @return {number}
*/
var findMedianSortedArrays = function (nums1, nums2) {
const m = nums1.length
const n = nums2.length

// 由于 k 的意义是第 k 个最小的数, 如果不加一就取成索引了.
const left = Math.floor((m + n + 1) / 2)
const right = Math.floor((m + n + 2) / 2)

// 求 left 与 right 之和是为了打平奇数长度和偶数长度, 由于最后算了两次, 所以最终结果要除以 2
return (
(findkth(nums1, 0, m - 1, nums2, 0, n - 1, left) +
findkth(nums1, 0, m - 1, nums2, 0, n - 1, right)) /
2
)
}

var findkth = function (arr1, start1, end1, arr2, start2, end2, k) {
// 获取两个数组的长度
const m = end1 - start1 + 1
const n = end2 - start2 + 1

// 如果数组 a 空了, 那最终答案就在数组 b 中寻找, 即 arr2[start2 + k - 1]
if (m === 0) return arr2[start2 + k - 1]

// 如果数组 b 空了, 那最终答案就在数组 a 中寻找, 即 arr1[start1 + k - 1]
if (n === 0) return arr1[start1 + k - 1]

// 如果 k === 1, 就说明两个数组的第一个元素中, 最小的那个就是答案
if (k === 1) return Math.min(arr1[start1], arr2[start2])

// 每次让数组长度(m 或 n) 与 Math.floor(k / 2) 比较, 取最小的那个
// 这样的目的是如果 Math.floor(k / 2) 比数组的长度大了, 如果去 Math.floor(k / 2) 的话, 数组就越界了
// 因此需要取两者中最小的, 就保证 i 或者 j 在这种情况就指到了数组的最后一个元素(下一次递归时这个数组长度就为 0 了)
const i = start1 + Math.min(m, Math.floor(k / 2)) - 1
const j = start2 + Math.min(n, Math.floor(k / 2)) - 1

if (arr1[i] > arr2[j]) {
// 如果 arr1[i] > arr2[j], 说明要把 arr2[j] 前 j 个干掉, 即把 start2 设为 j + 1,
// 此外由于数组 2 被削减了 j - start2 + 1 个, 所以 k 变成 k - (j - start2 + 1)
return findkth(arr1, start1, end1, arr2, j + 1, end2, k - (j - start2 + 1))
} else {
// 同理
return findkth(arr1, i + 1, end1, arr2, start2, end2, k - (i - start1 + 1))
}
}
  • 时间复杂度: O(log(m + n))
  • 空间复杂度: O(1), 虽然用到了递归, 但属于尾递归优化.