寻找两个正序数组的中位数
题目类型: 二分查找
题目
给定两个大小分别为 m
和 n
的正序数组 nums1
和 nums2
. 请你找出并返 回这两个正序数组的中位数. 要求使用 O(log(m + n))
的时间复杂度.
nums1.length == m
nums2.length == n
0 <= m <= 1000
0 <= n <= 1000
1 <= m + n <= 2000
-10⁶ <= nums1[i], nums2[i] <= 10⁶
输入: nums1 = [1, 3], nums2 = [2]
输出: 2
解释: 合并后的数组 = [1, 2, 3], 中位数 2
输入: nums1 = [1, 2], nums2 = [3, 4]
输出: 2.50000
解释: 合并数组 = [1, 2, 3, 4], 中位数 (2 + 3) / 2 = 2.5
题解
- JavaScript - 朴素解法
- JavaScript - 二分查找
- Rust
最朴素的解法是将两个数组按照从小到大合并到一起, 然后取中值, 需要考虑合并后的数组长度是奇数个还是偶数个.
/**
* @param {number[]} nums1
* @param {number[]} nums2
* @return {number}
*/
var findMedianSortedArrays = function (nums1, nums2) {
const m = nums1.length
const n = nums2.length
const total = m + n
const isOdd = total % 2 === 1
const mid = (total / 2) | 0
const arr = []
let i = 0,
j = 0
while (i + j <= mid) {
// 如果 i 已经走到头了, 需要把 nums2 剩下的元素放到 arr 最后
if (i === m) {
arr.push(...nums2.slice(j))
break
}
// 如果 j 已经走到头了, 需要把 nums1 剩下的元素放到 arr 最后
if (j === n) {
arr.push(...nums1.slice(i))
break
}
if (nums1[i] < nums2[j]) {
arr.push(nums1[i++])
} else {
arr.push(nums2[j++])
}
}
return isOdd ? arr[mid] : (arr[mid] + arr[mid - 1]) / 2
}
- 时间复杂度:
O(m + n)
- 空间复杂度:
O(m + n)
由于题目要求 O(log(m + n))
的时间复杂度, 那就要往二分查找上想. 题目要求求中位数, 其实就是求第 k
小的数. 举一个例子, 数组 a 是 [1, 3, 4, 9]
,
数组 b 是 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
, 我们找第 7 小的数字.
我们找第 k / 2
个数, 如果 k
是奇数, 向下取整, 因此得到的 k / 2
是 3
. 此时数组 a 第三个数是 4
, 数组 b 第三个数是 3
, 由于数组 b 的第三个数小于数组 a 第三个数,
所以数组 b 的前三个数, 也就是 1
, 2
, 3
一定不是第 7
小的数字.
此时, 我们去掉数组 b 的前三个数, 变成 [4, 5, 6, 7, 8, 9, 10]
, 并且由于去掉了三个数, 那 k
也相应的变成了 4
, 此时 k / 2
为 2
.
同样的方法, 数组 a 的第二个数为 3
, 数组 b 的第二个数为 5
, 所以此时把数组 a 的前两个数去掉, 此时数组 a 变成 [4, 9]
, 那 k
也相应的变成了 2
, 此时 k / 2
为 1
.
由于此时两个数组 a, b 的第一个数都是 4, 所以随便去掉一个就行, 比如我们把数组 b 的第一个元素干掉, 此时数组 b 变成 [5, 6, 7, 8, 9, 10]
, 那 k
也相应的变成了 1
.
这时候我们就可以停下来了, 因为 k
已经是 1
了, 我们就比较两个数组第一个元素哪个最小, 哪个就是第 k
小的元素, 所以最终答案为 4
.
当然这里还有一种特殊情况, 比如数组 a 在递归的过程中空了, 但此时 k
是 5
, 那就找数组 b 的第 5
个就好了.
当然在代码实践上, 我们并不是真正把数组前面的一截去掉, 而是通过指针来模拟起始位置.
/**
* @param {number[]} nums1
* @param {number[]} nums2
* @return {number}
*/
var findMedianSortedArrays = function (nums1, nums2) {
const m = nums1.length
const n = nums2.length
// 由于 k 的意义是第 k 个最小的数, 如果不加一就取成索引了.
const left = Math.floor((m + n + 1) / 2)
const right = Math.floor((m + n + 2) / 2)
// 求 left 与 right 之和是为了打平奇数长度和偶数长度, 由于最后算了两次, 所以最终结果要除以 2
return (
(findkth(nums1, 0, m - 1, nums2, 0, n - 1, left) +
findkth(nums1, 0, m - 1, nums2, 0, n - 1, right)) /
2
)
}
var findkth = function (arr1, start1, end1, arr2, start2, end2, k) {
// 获取两个数组的长度
const m = end1 - start1 + 1
const n = end2 - start2 + 1
// 如果数组 a 空了, 那最终答案就在数组 b 中寻找, 即 arr2[start2 + k - 1]
if (m === 0) return arr2[start2 + k - 1]
// 如果数组 b 空了, 那最终答案就在数组 a 中寻找, 即 arr1[start1 + k - 1]
if (n === 0) return arr1[start1 + k - 1]
// 如果 k === 1, 就说明两个数组 的第一个元素中, 最小的那个就是答案
if (k === 1) return Math.min(arr1[start1], arr2[start2])
// 每次让数组长度(m 或 n) 与 Math.floor(k / 2) 比较, 取最小的那个
// 这样的目的是如果 Math.floor(k / 2) 比数组的长度大了, 如果去 Math.floor(k / 2) 的话, 数组就越界了
// 因此需要取两者中最小的, 就保证 i 或者 j 在这种情况就指到了数组的最后一个元素(下一次递归时这个数组长度就为 0 了)
const i = start1 + Math.min(m, Math.floor(k / 2)) - 1
const j = start2 + Math.min(n, Math.floor(k / 2)) - 1
if (arr1[i] > arr2[j]) {
// 如果 arr1[i] > arr2[j], 说明要把 arr2[j] 前 j 个干掉, 即把 start2 设为 j + 1,
// 此外由于数组 2 被削减了 j - start2 + 1 个, 所以 k 变成 k - (j - start2 + 1)
return findkth(arr1, start1, end1, arr2, j + 1, end2, k - (j - start2 + 1))
} else {
// 同理
return findkth(arr1, i + 1, end1, arr2, start2, end2, k - (i - start1 + 1))
}
}
- 时间复杂度:
O(log(m + n))
- 空间复杂度:
O(1)
, 虽然用到了递归, 但属于尾递归优化.
use std::cmp;
pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
let m = nums1.len();
let n = nums2.len();
let left = (m + n + 1) / 2;
let right = (m + n + 2) / 2;
(find_kth(&nums1, 0, m - 1, &nums2, 0, n - 1, left)
+ find_kth(&nums1, 0, m - 1, &nums2, 0, n - 1, right))
/ 2.0
}
fn find_kth(
arr1: &Vec<i32>,
start1: usize,
end1: usize,
arr2: &Vec<i32>,
start2: usize,
end2: usize,
k: usize,
) -> f64 {
let m = end1 - start1 + 1;
let n = end2 - start2 + 1;
if m == 0 {
return arr2[start2 + k - 1].into();
}
if n == 0 {
return arr1[start1 + k - 1].into();
}
if k == 1 {
return cmp::min(arr1[start1], arr2[start2]).into();
}
let i = start1 + cmp::min(m, k / 2) - 1;
let j = start2 + cmp::min(n, k / 2) - 1;
if arr1[i] > arr2[j] {
return find_kth(arr1, start1, end1, arr2, j + 1, end2, k - (j - start2 + 1));
} else {
return find_kth(arr1, i + 1, end1, arr2, start2, end2, k - (i - start1 + 1));
}
}