Random weighted selection in Java

yos*_*osi 58 java random double

I want to pick a random item from a set, but the chance of picking any item should be proportional to its associated weight.

Example input:

item                weight
----                ------
sword of misery         10
shield of happy          5
potion of dying          6
triple-edged sword       1

So, if I have 4 possible items, the chance of getting any one of them without weights would be 1 in 4.

In this case, a user should be 10 times more likely to get the sword of misery than the triple-edged sword.
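With weights, each item's chance is its weight divided by the sum of all weights. For the table above the sum is 10 + 5 + 6 + 1 = 22, so the sword of misery would be picked with probability 10/22 ≈ 0.455 and the triple-edged sword with 1/22 ≈ 0.045, which is the 10:1 ratio asked for. A small sketch of that calculation (the class and method names are made up for illustration):

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper: turns raw weights into selection probabilities (weight / total).
class WeightShares {
    static Map<String, Double> probabilities(Map<String, Integer> weights) {
        int total = 0;
        for (int w : weights.values()) {
            total += w; // for the example table: 10 + 5 + 6 + 1 = 22
        }
        Map<String, Double> result = new LinkedHashMap<>();
        for (Map.Entry<String, Integer> e : weights.entrySet()) {
            result.put(e.getKey(), e.getValue() / (double) total);
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Integer> weights = new LinkedHashMap<>();
        weights.put("sword of misery", 10);
        weights.put("shield of happy", 5);
        weights.put("potion of dying", 6);
        weights.put("triple-edged sword", 1);
        probabilities(weights).forEach((item, p) -> System.out.printf("%-20s %.3f%n", item, p));
    }
}
```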

How can I make a weighted random selection in Java?

Pet*_*rey 104

I would use a NavigableMap:

import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

public class RandomCollection<E> {
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        if (weight <= 0) return this;
        total += weight;
        map.put(total, result);
        return this;
    }

    public E next() {
        double value = random.nextDouble() * total;
        return map.higherEntry(value).getValue();
    }
}

Suppose I have a list of animals, dog, cat and horse, with probabilities of 40%, 35% and 25% respectively:

RandomCollection<String> rc = new RandomCollection<>()
                              .add(40, "dog").add(35, "cat").add(25, "horse");

for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
} 
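To see why this works, it may help to look at the map that the three `add` calls build: each key is the running total of the weights so far (40, 75, 100), and `higherEntry` returns the first key strictly greater than the drawn value. Since the value is uniform in [0, 100), the intervals [0, 40), [40, 75) and [75, 100) have lengths equal to the weights. The sketch below rebuilds that map by hand and probes a few fixed values instead of random ones (class and method names are hypothetical):

```java
import java.util.NavigableMap;
import java.util.TreeMap;

// Rebuilds by hand the map produced by add(40, "dog").add(35, "cat").add(25, "horse").
class CumulativeKeysDemo {
    static NavigableMap<Double, String> buildMap() {
        NavigableMap<Double, String> map = new TreeMap<>();
        map.put(40.0, "dog");    // covers [0, 40)
        map.put(75.0, "cat");    // covers [40, 75)
        map.put(100.0, "horse"); // covers [75, 100)
        return map;
    }

    // Same lookup as next(): the entry with the smallest key strictly greater than value.
    static String lookup(NavigableMap<Double, String> map, double value) {
        return map.higherEntry(value).getValue();
    }

    public static void main(String[] args) {
        NavigableMap<Double, String> map = buildMap();
        for (double value : new double[] {0.0, 39.99, 40.0, 74.5, 75.0, 99.99}) {
            System.out.println(value + " -> " + lookup(map, value));
        }
    }
}
```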

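  • Thanks for the answer, Peter! It works great. If anyone, like me, wonders about `if (weight <= 0) return this;`, it serves an important purpose. Without it, if you reuse the example and call `.add(0, "lizard")` *after* calling `.add(25, "horse")`, the entry for "horse" in the map would be overwritten: the new `put(total, result)` call uses the *same* total weight as the previous entry, so "horse" is replaced by "lizard", even though the lizard should have a 0% chance of being picked. (3 upvotes)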
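The overwrite described in that comment can be reproduced directly with a `TreeMap`. This sketch mimics `add` without the `weight <= 0` guard (the class and method names are made up): a zero weight leaves the running total unchanged, so `put` reuses the key 100.0 and replaces "horse".

```java
import java.util.NavigableMap;
import java.util.TreeMap;

// Hypothetical demo of add() without the weight <= 0 guard.
class ZeroWeightOverwriteDemo {
    static NavigableMap<Double, String> addWithoutGuard() {
        NavigableMap<Double, String> map = new TreeMap<>();
        double total = 0;
        double[] weights = {40, 35, 25, 0};
        String[] names = {"dog", "cat", "horse", "lizard"};
        for (int i = 0; i < weights.length; i++) {
            total += weights[i];
            map.put(total, names[i]); // the last iteration puts key 100.0 a second time
        }
        return map;
    }

    public static void main(String[] args) {
        // Prints {40.0=dog, 75.0=cat, 100.0=lizard}: "horse" has disappeared.
        System.out.println(addWithoutGuard());
    }
}
```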

Arn*_*sch 23

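You won't find a framework for this kind of problem, since the requested functionality is nothing more than a simple function. Do something like this: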

import java.util.List;

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();
        double r = Math.random() * completeWeight;
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight > r) // strict, so an item with weight 0 is never returned
                return item;
        }
        throw new RuntimeException("Should never be shown.");
    }
}
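The loop above rescans the whole list on every call. If the same list is sampled many times, one common alternative (not part of this answer) is to compute the cumulative weights once and binary-search them for each draw, which is roughly the array equivalent of the NavigableMap answer above. A minimal sketch, assuming plain `double[]` weights instead of the `Item` interface; `CumulativeWeightSampler`, `indexFor` and `next` are hypothetical names:

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical helper: cumulative weights are computed once,
// then each draw is an O(log n) binary search instead of an O(n) scan.
class CumulativeWeightSampler {
    private final double[] cumulative; // cumulative[i] = weights[0] + ... + weights[i]

    CumulativeWeightSampler(double[] weights) {
        if (weights.length == 0) {
            throw new IllegalArgumentException("weights must not be empty");
        }
        cumulative = new double[weights.length];
        double sum = 0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] < 0) {
                throw new IllegalArgumentException("weights must not be negative");
            }
            sum += weights[i];
            cumulative[i] = sum;
        }
        if (sum <= 0) {
            throw new IllegalArgumentException("total weight must be positive");
        }
    }

    // First index whose cumulative weight is strictly greater than r.
    // For 0 <= r < total this never lands on a zero-weight item, even with repeated cumulative values.
    int indexFor(double r) {
        int lo = 0;
        int hi = cumulative.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (cumulative[mid] > r) {
                hi = mid;     // mid may be the answer
            } else {
                lo = mid + 1; // the answer lies to the right of mid
            }
        }
        return lo;
    }

    int next(Random random) {
        double total = cumulative[cumulative.length - 1];
        return indexFor(random.nextDouble() * total);
    }

    public static void main(String[] args) {
        double[] weights = {10, 5, 6, 1}; // the example table from the question
        CumulativeWeightSampler sampler = new CumulativeWeightSampler(weights);
        int[] counts = new int[weights.length];
        Random random = new Random(42);
        for (int i = 0; i < 22_000; i++) {
            counts[sampler.next(random)]++;
        }
        System.out.println(Arrays.toString(counts)); // expected to be roughly in a 10:5:6:1 ratio
    }
}
```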


kdk*_*eck 20

Apache Commons Math now has a class for this: `EnumeratedDistribution`.

Item selectedItem = new EnumeratedDistribution<>(itemWeights).sample();

Here `itemWeights` is a `List<Pair<Item, Double>>`, built for example like this (assuming the `Item` interface from Arne's answer):

final List<Pair<Item, Double>> itemWeights = new ArrayList<>();
for (Item i : itemSet) {
    itemWeights.add(new Pair<>(i, i.getWeight()));
}

Or, in Java 8:

List<Pair<Item, Double>> itemWeights = itemSet.stream().map(i -> new Pair<>(i, i.getWeight())).collect(Collectors.toList());

Note: `Pair` here needs to be `org.apache.commons.math3.util.Pair`.

  • This really should be higher up in the list of answers... why reinvent the wheel? Also, `EnumeratedDistribution` lets you draw several samples at once, which is really neat. (2 upvotes)
  • Commons Math3 is no longer supported. The functionality of `EnumeratedDistribution` has moved to [DiscreteProbabilityCollectionSampler](https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.html) in the [Commons RNG](https://commons.apache.org/proper/commons-rng/) library. (2 upvotes)

Oli*_*ire 6

Use the alias method

If you are going to roll many times (for example, in a game), you should use the alias method.

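Admittedly, the code below is a rather long implementation of the alias method, but that is because of the initialization part. Retrieving an element is very fast (see the `next` and `applyAsInt` methods, which don't loop).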
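A worked example may make the two-step lookup clearer: pick one of n columns uniformly, then keep that column with its acceptance probability, otherwise return its alias. For dog/cat/horse = 0.40/0.35/0.25 (n = 3), scaling by n gives 1.2, 1.05 and 0.75. Working Vose's method by hand: horse (0.75) takes cat as its alias, leaving cat at 1.05 − 0.25 = 0.8; cat then takes dog as its alias, leaving dog at 1.2 − 0.2 = 1.0. The sketch below only checks that hand-built table by summing, exactly, the probability each index is returned; it is not the answer's code, and the class name is hypothetical.

```java
import java.util.Arrays;

// Hypothetical check of a hand-built alias table (0 = dog, 1 = cat, 2 = horse).
class AliasTableCheck {
    // Exact probability that "pick a column, then keep it or take its alias" returns each index.
    static double[] impliedDistribution(double[] accept, int[] alias) {
        int n = accept.length;
        double[] p = new double[n];
        for (int column = 0; column < n; column++) {
            p[column] += accept[column] / n;              // column chosen, then kept
            p[alias[column]] += (1 - accept[column]) / n; // column chosen, then its alias taken
        }
        return p;
    }

    public static void main(String[] args) {
        double[] accept = {1.0, 0.8, 0.75};
        int[] alias = {0, 0, 1}; // dog's alias is never used because accept[0] == 1.0
        System.out.println(Arrays.toString(impliedDistribution(accept, alias)));
    }
}
```

Each draw costs one uniform column choice and one comparison, which is where the O(1) retrieval comes from.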

Usage

Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;

Random random = new Random();

RandomSelector<Item> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);

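Implementation

This implementation:

  • uses Java 8;
  • is designed to be as fast as possible (at least, I tried to achieve that using micro-benchmarks);
  • is fully thread-safe (for maximum performance, keep one `Random` per thread, or use `ThreadLocalRandom`);
  • retrieves elements in O(1), unlike what you mostly find on the Internet or on Stack Overflow, where implementations run in O(n) or O(log(n));
  • keeps items independent of their weights, so the same item can be given different weights in different contexts.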

Anyway, here is the code. (Note that I maintain an up-to-date version of this class.)

import static java.util.Objects.requireNonNull;

import java.util.*;
import java.util.function.*;

public final class RandomSelector<T> {

  public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
      throws IllegalArgumentException {
    requireNonNull(elements, "elements must not be null");
    requireNonNull(weighter, "weighter must not be null");
    if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); }

    // Array is faster than anything. Use that.
    int size = elements.size();
    T[] elementArray = elements.toArray((T[]) new Object[size]);

    double totalWeight = 0d;
    double[] discreteProbabilities = new double[size];

    // Retrieve the probabilities
    for (int i = 0; i < size; i++) {
      double weight = weighter.applyAsDouble(elementArray[i]);
      if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); }
      discreteProbabilities[i] = weight;
      totalWeight += weight;
    }
    if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); }

    // Normalize the probabilities
    for (int i = 0; i < size; i++) {
      discreteProbabilities[i] /= totalWeight;
    }
    return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
  }

  private final T[] elements;
  private final ToIntFunction<Random> selection;

  private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
    this.elements = elements;
    this.selection = selection;
  }

  public T next(Random random) {
    return elements[selection.applyAsInt(random)];
  }

  private static class RandomWeightedSelection implements ToIntFunction<Random> {
    // Alias method implementation O(1)
    // using Vose's algorithm to initialize O(n)

    private final double[] probabilities;
    private final int[] alias;

    RandomWeightedSelection(double[] probabilities) {
      int size = probabilities.length;
      // The normalized array is reused in place and ends up holding each column's acceptance probability.
      this.probabilities = probabilities;
      this.alias = new int[size];

      double average = 1.0d / size;
      int[] small = new int[size];
      int smallSize = 0;
      int[] large = new int[size];
      int largeSize = 0;

      // Describe a column as either small (below average) or large (above average).
      for (int i = 0; i < size; i++) {
        if (probabilities[i] < average) {
          small[smallSize++] = i;
        } else {
          large[largeSize++] = i;
        }
      }

      // For each column, saturate a small probability to average with a large probability.
      while (largeSize != 0 && smallSize != 0) {
        int less = small[--smallSize];
        int more = large[--largeSize];
        alias[less] = more;
        // Take the missing mass (average - p[less]) from the large column, using unscaled values.
        probabilities[more] += probabilities[less] - average;
        // Only then scale p[less] into an acceptance probability in [0, 1).
        probabilities[less] *= size;
        if (probabilities[more] < average) {
          small[smallSize++] = more;
        } else {
          large[largeSize++] = more;
        }
      }

      // Flush unused columns.
      while (smallSize != 0) {
        probabilities[small[--smallSize]] = 1.0d;
      }
      while (largeSize != 0) {
        probabilities[large[--largeSize]] = 1.0d;
      }
    }

    @Override public int applyAsInt(Random random) {
      // Call random once to decide which column will be used.
      int column = random.nextInt(probabilities.length);

      // Call random a second time to decide which will be used: the column or the alias.
      if (random.nextDouble() < probabilities[column]) {
        return column;
      } else {
        return alias[column];
      }
    }
  }
}


Anonymous 6


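There is a straightforward algorithm for picking an item at random when items have individual weights:

  1. Calculate the sum of all the weights.

  2. Pick a random number that is 0 or greater and less than the sum of the weights.

  3. Go through the items one at a time, subtracting each item's weight from your random number, until you reach the item for which the random number is less than that item's weight.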
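A direct transcription of the three steps, as a sketch assuming non-negative integer weights with a positive sum (integers keep the arithmetic exact; `SubtractionPick`, `pick` and `pickIndex` are hypothetical names):

```java
import java.util.Random;

// Hypothetical transcription of the three-step subtraction algorithm.
class SubtractionPick {
    // Step 3: walk the items, subtracting each weight from r,
    // until r is smaller than the current item's weight. Requires 0 <= r < sum of weights.
    static int pickIndex(int[] weights, int r) {
        for (int i = 0; i < weights.length; i++) {
            if (r < weights[i]) {
                return i;
            }
            r -= weights[i];
        }
        throw new IllegalArgumentException("r must be less than the sum of the weights");
    }

    static int pick(int[] weights, Random random) {
        int total = 0;
        for (int w : weights) {
            total += w;                  // step 1: sum of all weights
        }
        int r = random.nextInt(total);   // step 2: 0 <= r < total
        return pickIndex(weights, r);    // step 3
    }

    public static void main(String[] args) {
        // sword of misery, shield of happy, potion of dying, triple-edged sword
        int[] weights = {10, 5, 6, 1};
        System.out.println(pick(weights, new Random()));
    }
}
```

Splitting out `pickIndex` keeps step 3 deterministic, so it can be checked with fixed values of r.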


Anonymous 5

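import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;

public class RandomCollection<E> {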
  private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
  private double total = 0;

  public void add(double weight, E result) {
    if (weight <= 0 || map.containsValue(result))
      return;
    total += weight;
    map.put(total, result);
  }

  public E next() {
    double value = ThreadLocalRandom.current().nextDouble() * total;
    return map.ceilingEntry(value).getValue();
  }
}
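This variant differs from Peter's answer mainly in using `ceilingEntry` instead of `higherEntry` (and in silently ignoring items that were already added). The two lookups only disagree when the drawn value lands exactly on a cumulative total: `higherEntry` maps values into [0, 40), [40, 75), [75, 100), while `ceilingEntry` maps them into [0, 40], (40, 75], (75, 100). Each interval's length still equals the item's weight, so in practice the choice mainly affects the boundaries. A small sketch on the dog/cat/horse map (class name hypothetical):

```java
import java.util.NavigableMap;
import java.util.TreeMap;

// Hypothetical comparison of ceilingEntry and higherEntry on the cumulative map (keys 40, 75, 100).
class CeilingVsHigher {
    static NavigableMap<Double, String> buildMap() {
        NavigableMap<Double, String> map = new TreeMap<>();
        map.put(40.0, "dog");
        map.put(75.0, "cat");
        map.put(100.0, "horse");
        return map;
    }

    public static void main(String[] args) {
        NavigableMap<Double, String> map = buildMap();
        // Away from the keys, both lookups agree: prints "cat cat".
        System.out.println(map.ceilingEntry(50.0).getValue() + " " + map.higherEntry(50.0).getValue());
        // On an exact key they differ: ceilingEntry includes the key, higherEntry skips it. Prints "dog cat".
        System.out.println(map.ceilingEntry(40.0).getValue() + " " + map.higherEntry(40.0).getValue());
    }
}
```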