寻找两个正序数组的中位数
Tips
题目类型: 二分查找
题目
给定两个大小分别为 m 和 n 的正序数组 nums1 和 nums2. 请你找出并返回这两个正序数组的中位数. 要求使用 O(log(m + n)) 的时间复杂度.
提示:
nums1.length == mnums2.length == n0 <= m <= 10000 <= n <= 10001 <= m + n <= 2000-10⁶ <= nums1[i], nums2[i] <= 10⁶
示例
输入: nums1 = [1, 3], nums2 = [2]
输出: 2
解释: 合并后的数组 = [1, 2, 3], 中位数 2
输入: nums1 = [1, 2], nums2 = [3, 4]
输出: 2.50000
解释: 合并数组 = [1, 2, 3, 4], 中位数 (2 + 3) / 2 = 2.5
题解
- JavaScript - 二分查找
- Rust
由于题目要求 O(log(m + n)) 的时间复杂度, 那就要往二分查找上想. 题目要求求中位数, 其实就是求第 k 小的数.
核心思路: 每次排除掉 k/2 个元素.
假设我们要找第 k 小的数, 我们可以比较两个数组中第 k/2 个元素(即下标为 k/2 - 1):
- 如果
nums1[k/2 - 1] < nums2[k/2 - 1], 那么nums1的前k/2个元素一定都在前k小的元素范围内, 且不包含第k小的那个数.- 为什么? 假设
nums1的第k/2个元素比nums2的第k/2个元素小, 即使nums2的前k/2 - 1个元素都比nums1的第k/2个元素小, 这两部分加起来也只有k/2 + k/2 - 1 = k - 1个元素. 所以nums1的第k/2个元素最多只能是第k-1小的数. - 因此, 我们可以放心地把
nums1的前k/2个元素"逻辑上"移除(通过移动起始索引).
- 为什么? 假设
- 反之, 如果
nums1[k/2 - 1] > nums2[k/2 - 1], 则可以排除nums2的前k/2个元素.
举例说明:
假设 nums1 = [1, 3, 4, 9], nums2 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 我们要找第 7 小的数字 (k=7).
-
第一轮:
k=7,k/2=3. 比较nums1[2](即 4) 和nums2[2](即 3).- 因为
3 < 4, 所以nums2的前 3 个元素[1, 2, 3]肯定不是第 7 小的数, 可以排除. nums2逻辑上变为[4, 5, 6, 7, 8, 9, 10].- 我们排除了 3 个元素, 所以现在要找第
7 - 3 = 4小的数.
- 因为
-
第二轮:
k=4,k/2=2. 比较nums1[1](即 3) 和nums2新起点的第 2 个元素 (即 5).- 注意:
nums2的新起点是原数组下标 3 (即值 4). 所以比较的是nums1[1](3) 和nums2[3+2-1](5). - 因为
3 < 5, 所以nums1的前 2 个元素[1, 3]可以排除. nums1逻辑上变为[4, 9].- 我们又排除了 2 个元素, 所以现在要找第
4 - 2 = 2小的数.
- 注意:
-
第三轮:
k=2,k/2=1. 比较nums1新起点第 1 个元素 (4) 和nums2新起点第 1 个元素 (4).- 值相等, 我们可以任意排除一方. 假设排除
nums1的. nums1逻辑上变为[9].k变为2 - 1 = 1.
- 值相等, 我们可以任意排除一方. 假设排除
-
终止:
k=1. 此时只要比较两个数组当前起始位置的元素, 较小的那个就是答案.nums1当前是[9],nums2当前是[4, 5, ...].min(9, 4) = 4.- 所以第 7 小的数是
4.
在代码实现中, 我们使用 start1 和 start2 指针来标记逻辑上的数组起始位置, 避免真正的数组切片操作.
/**
* @param {number[]} nums1
* @param {number[]} nums2
* @return {number}
*/
var findMedianSortedArrays = function (nums1, nums2) {
const m = nums1.length
const n = nums2.length
// 由于 k 的意义是第 k 个最小的数, 如果不加一就取成索引了.
const left = Math.floor((m + n + 1) / 2)
const right = Math.floor((m + n + 2) / 2)
// 求 left 与 right 之和是为了打平奇数长度和偶数长度, 由于最后算了两次, 所以最终结果要除以 2
return (
(findkth(nums1, 0, m - 1, nums2, 0, n - 1, left) +
findkth(nums1, 0, m - 1, nums2, 0, n - 1, right)) /
2
)
}
var findkth = function (arr1, start1, end1, arr2, start2, end2, k) {
// 获取两个数组的长度
const m = end1 - start1 + 1
const n = end2 - start2 + 1
// 如果数组 a 空了, 那最终答案就在数组 b 中寻找, 即 arr2[start2 + k - 1]
if (m === 0) return arr2[start2 + k - 1]
// 如果数组 b 空了, 那最终答案就在数组 a 中寻找, 即 arr1[start1 + k - 1]
if (n === 0) return arr1[start1 + k - 1]
// 如果 k === 1, 就说明两个数组的第一个元素中, 最小的那个就是答案
if (k === 1) return Math.min(arr1[start1], arr2[start2])
// 每次让数组长度(m 或 n) 与 Math.floor(k / 2) 比较, 取最小的那个
// 这样的目的是如果 Math.floor(k / 2) 比数组的长度大了, 如果去 Math.floor(k / 2) 的话, 数组就越界了
// 因此需要取两者中最小的, 就保证 i 或者 j 在这种情况就指到了数组的最后一个元素(下一次递归时这个数组长度就为 0 了)
const i = start1 + Math.min(m, Math.floor(k / 2)) - 1
const j = start2 + Math.min(n, Math.floor(k / 2)) - 1
if (arr1[i] > arr2[j]) {
// 如果 arr1[i] > arr2[j], 说明要把 arr2[j] 前 j 个干掉, 即把 start2 设为 j + 1,
// 此外由于数组 2 被削减了 j - start2 + 1 个, 所以 k 变成 k - (j - start2 + 1)
return findkth(arr1, start1, end1, arr2, j + 1, end2, k - (j - start2 + 1))
} else {
// 同理
return findkth(arr1, i + 1, end1, arr2, start2, end2, k - (i - start1 + 1))
}
}
- 时间复杂度:
O(log(m + n)) - 空间复杂度:
O(1), 虽然用到了递归, 但属于尾递归优化.
use std::cmp;
pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
let m = nums1.len();
let n = nums2.len();
let left = (m + n + 1) / 2;
let right = (m + n + 2) / 2;
(find_kth(&nums1, 0, m - 1, &nums2, 0, n - 1, left)
+ find_kth(&nums1, 0, m - 1, &nums2, 0, n - 1, right))
/ 2.0
}
fn find_kth(
arr1: &Vec<i32>,
start1: usize,
end1: usize,
arr2: &Vec<i32>,
start2: usize,
end2: usize,
k: usize,
) -> f64 {
let m = end1 - start1 + 1;
let n = end2 - start2 + 1;
if m == 0 {
return arr2[start2 + k - 1].into();
}
if n == 0 {
return arr1[start1 + k - 1].into();
}
if k == 1 {
return cmp::min(arr1[start1], arr2[start2]).into();
}
let i = start1 + cmp::min(m, k / 2) - 1;
let j = start2 + cmp::min(n, k / 2) - 1;
if arr1[i] > arr2[j] {
return find_kth(arr1, start1, end1, arr2, j + 1, end2, k - (j - start2 + 1));
} else {
return find_kth(arr1, i + 1, end1, arr2, start2, end2, k - (i - start1 + 1));
}
}