按权重随机选择
Tips
题目类型: 前缀和 + 二分
题目
给你一个下标从 0 开始的正整数数组 w, 其中 w[i] 代表第 i 个下标的权重.
请你实现一个函数 pickIndex
, 它可以随机地从范围 [0, w.length - 1]
内(含 0
和 w.length - 1
)选出并返回一个下标. 选取下标 i
的概率 为 w[i] / sum(w)
.
例如, 对于 w = [1, 3
], 挑选下标 0
的概率为 1 / (1 + 3) = 0.25
(即, 25%
), 而选取下标 1
的概率为 3 / (1 + 3) = 0.75(即, 75%)
.
示例
Solution solution = new Solution([1, 3]);
solution.pickIndex(); // 返回 1, 返回下标 1, 返回该下标概率为 3/4.
solution.pickIndex(); // 返回 1
solution.pickIndex(); // 返回 1
solution.pickIndex(); // 返回 1
solution.pickIndex(); // 返回 0, 返回下标 0, 返回该下标概率为 1/4.
由于这是一个随机问题, 允许多个答案, 因此下列输出都可以被认为是正确的:
[null,1,1,1,1,0]
[null,1,1,1,1,1]
[null,1,1,1,0,0]
[null,1,1,1,0,1]
[null,1,0,1,0,0]
......
诸若此类.
题解
假设 w = [3, 1, 2, 4]
, 那么 total
就是 10
, 题目要求选取下标 i
的概率 为 w[i] / sum(w)
, 也就是如果选 3
, 那应该抽取概率占 3 / 10
; 如果选 1
, 那应该抽取概率占 3 / 10
, 以此类推...
所以可以这么想, 把数组 w
变成 [[1, 3], [4, 4], [5, 6], [7, 10]]
, 它的意思就是说第一个元素 3
占总体的三份, 那就给它 [1, 3]
这个区间, 正好这个区间有三个数字; 第二个元素 1
占总体的一份, 那就给它 [4, 4]
这个区间, 正好这个区间有一个数字, 以此类推...
我们观察这个二维数组中每个数组的第二个元素, 也就是 [3, 4, 6, 10]
, 其实是个前缀和序列, 即 preSum[i] = preSum[i - 1] + w[i]
因此只要随机数 x
最小的满足 x <= preSum[i]
, 就保证在相应的区间里, 也就保证了选取下标 i
的概率 为 w[i] / sum(w)
. 由于 preSum
是个有序数组, 用二分找即可.
- JavaScript
- Rust
/**
* @param {number[]} w
*/
var Solution = function (w) {
const n = w.length
const preSum = new Array(n).fill(0)
preSum[0] = w[0]
for (let i = 1; i < n; i++) {
preSum[i] = preSum[i - 1] + w[i]
}
this.preSum = preSum
this.total = w.reduce((acc, val) => acc + val)
}
/**
* @return {number}
*/
Solution.prototype.pickIndex = function () {
const x = Math.floor(Math.random() * this.total) + 1
return this.binarySearch(x)
}
Solution.prototype.binarySearch = function (val) {
let low = 0,
high = this.preSum.length - 1
while (low < high) {
const mid = low + Math.floor((high - low) / 2)
if (this.preSum[mid] < val) {
low = mid + 1
} else {
high = mid
}
}
return low
}
/**
* Your Solution object will be instantiated and called as such:
* var obj = new Solution(w)
* var param_1 = obj.picki()
*/
use rand::Rng;
struct Solution {
total: i32,
pre_sum: Vec<i32>,
}
/**
* `&self` means the method takes an immutable reference.
* If you need a mutable reference, change it to `&mut self` instead.
*/
impl Solution {
fn new(w: Vec<i32>) -> Self {
let n = w.len();
let mut pre_sum = vec![0; n];
pre_sum[0] = w[0];
for i in 1..n {
pre_sum[i] = pre_sum[i - 1] + w[i];
}
Solution {
total: w.iter().sum(),
pre_sum,
}
}
fn pick_index(&self) -> i32 {
let x = rand::thread_rng().gen_range(0..self.total) + 1;
self.pre_sum
.binary_search(&x)
.unwrap_or_else(|e| self.pre_sum[e].try_into().unwrap()) as i32
}
}
/**
* Your Solution object will be instantiated and called as such:
* let obj = Solution::new(w);
* let ret_1: i32 = obj.pick_index();
*/