如何使用 argmax 方法扩展 Scala 集合?

发布于 2024-09-06 00:01:55 字数 120 浏览 8 评论 0原文

我想在所有有意义的集合中添加 argMax 方法。 怎么做呢?使用隐式?

I would like to add to all collections where it makes sense, an argMax method.
How to do it? Use implicits?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(8

琉璃繁缕 2024-09-13 00:01:55

在 Scala 2.8 上,这是有效的:

val list = List(1, 2, 3)
def f(x: Int) = -x
val argMax = list max (Ordering by f)

正如 mkneissl 所指出的,这不会返回以下集合最大点数。这是一个替代实现,它会尝试减少对 f 的调用次数。如果对 f 的调用并不重要,请参阅 mkneissl 的回答。另请注意,他的答案是柯里化的,这提供了优越的类型推断。

def argMax[A, B: Ordering](input: Iterable[A], f: A => B) = {
  val fList = input map f
  val maxFList = fList.max
  input.view zip fList filter (_._2 == maxFList) map (_._1) toSet
}

scala> argMax(-2 to 2, (x: Int) => x * x)
res15: scala.collection.immutable.Set[Int] = Set(-2, 2)

On Scala 2.8, this works:

val list = List(1, 2, 3)
def f(x: Int) = -x
val argMax = list max (Ordering by f)

As pointed by mkneissl, this does not return the set of maximum points. Here's an alternate implementation that does, and tries to reduce the number of calls to f. If calls to f don't matter that much, see mkneissl's answer. Also, note that his answer is curried, which provides superior type inference.

def argMax[A, B: Ordering](input: Iterable[A], f: A => B) = {
  val fList = input map f
  val maxFList = fList.max
  input.view zip fList filter (_._2 == maxFList) map (_._1) toSet
}

scala> argMax(-2 to 2, (x: Int) => x * x)
res15: scala.collection.immutable.Set[Int] = Set(-2, 2)
无声无音无过去 2024-09-13 00:01:55

argmax 函数(据我从 Wikipedia 理解)

def argMax[A,B](c: Traversable[A])(f: A=>B)(implicit o: Ordering[B]): Traversable[A] = {
  val max = (c map f).max(o)
  c filter { f(_) == max }
}

如果你真的想要,你可以pimp它到集合上

implicit def enhanceWithArgMax[A](c: Traversable[A]) = new {
  def argMax[B](f: A=>B)(implicit o: Ordering[B]): Traversable[A] = ArgMax.argMax(c)(f)(o)
}

并像这样使用它

val l = -2 to 2
assert (argMax(l)(x => x*x) == List(-2,2))
assert (l.argMax(x => x*x) == List(-2,2))

(Scala 2.8)

The argmax function (as I understand it from Wikipedia)

def argMax[A,B](c: Traversable[A])(f: A=>B)(implicit o: Ordering[B]): Traversable[A] = {
  val max = (c map f).max(o)
  c filter { f(_) == max }
}

If you really want, you can pimp it onto the collections

implicit def enhanceWithArgMax[A](c: Traversable[A]) = new {
  def argMax[B](f: A=>B)(implicit o: Ordering[B]): Traversable[A] = ArgMax.argMax(c)(f)(o)
}

and use it like this

val l = -2 to 2
assert (argMax(l)(x => x*x) == List(-2,2))
assert (l.argMax(x => x*x) == List(-2,2))

(Scala 2.8)

阪姬 2024-09-13 00:01:55

是的,通常的方法是使用“为我的图书馆拉皮条”图案来装饰您的收藏。例如(NB只是作为说明,并不意味着是一个正确或有效的示例):


trait PimpedList[A] {
   val l: List[A]

  //example argMax, not meant to be correct
  def argMax[T <% Ordered[T]](f:T => T) = {error("your definition here")}
}

implicit def toPimpedList[A](xs: List[A]) = new PimpedList[A] { 
  val l = xs
}

scala> def f(i:Int):Int = 10
f: (i: Int) Int

scala> val l = List(1,2,3)
l: List[Int] = List(1, 2, 3)

scala> l.argMax(f)
java.lang.RuntimeException: your definition here
    at scala.Predef$.error(Predef.scala:60)
    at PimpedList$class.argMax(:12)
        //etc etc...

Yes, the usual way would be to use the 'pimp my library' pattern to decorate your collection. For example (N.B. just as illustration, not meant to be a correct or working example):


trait PimpedList[A] {
   val l: List[A]

  //example argMax, not meant to be correct
  def argMax[T <% Ordered[T]](f:T => T) = {error("your definition here")}
}

implicit def toPimpedList[A](xs: List[A]) = new PimpedList[A] { 
  val l = xs
}

scala> def f(i:Int):Int = 10
f: (i: Int) Int

scala> val l = List(1,2,3)
l: List[Int] = List(1, 2, 3)

scala> l.argMax(f)
java.lang.RuntimeException: your definition here
    at scala.Predef$.error(Predef.scala:60)
    at PimpedList$class.argMax(:12)
        //etc etc...
清欢 2024-09-13 00:01:55

又好又容易? :

val l = List(1,0,10,2)
l.zipWithIndex.maxBy(x => x._1)._2

Nice and easy ? :

val l = List(1,0,10,2)
l.zipWithIndex.maxBy(x => x._1)._2
隔岸观火 2024-09-13 00:01:55

您可以使用 Pimp my Library 模式。您可以通过定义隐式转换函数来实现此目的。例如,我有一个类 Vector3 来表示 3D 矢量:

class Vector3 (val x: Float, val y: Float, val z: Float)

假设我希望能够通过编写以下内容来缩放矢量:2.5f * v。当然,我不能直接向类 Float 添加 * 方法,但我可以提供一个隐式转换函数,如下所示:

implicit def scaleVector3WithFloat(f: Float) = new {
    def *(v: Vector3) = new Vector3(f * v.x, f * v.y, f * v.z)
}

请注意,这会返回一个结构类型的对象(包含 * 方法的 new { ... } 构造)。

我还没有测试过,但我想你可以这样做:

implicit def argMaxImplicit[A](t: Traversable[A]) = new {
    def argMax() = ...
}

You can add functions to an existing API in Scala by using the Pimp my Library pattern. You do this by defining an implicit conversion function. For example, I have a class Vector3 to represent 3D vectors:

class Vector3 (val x: Float, val y: Float, val z: Float)

Suppose I want to be able to scale a vector by writing something like: 2.5f * v. I can't directly add a * method to class Float ofcourse, but I can supply an implicit conversion function like this:

implicit def scaleVector3WithFloat(f: Float) = new {
    def *(v: Vector3) = new Vector3(f * v.x, f * v.y, f * v.z)
}

Note that this returns an object of a structural type (the new { ... } construct) that contains the * method.

I haven't tested it, but I guess you could do something like this:

implicit def argMaxImplicit[A](t: Traversable[A]) = new {
    def argMax() = ...
}
偏爱你一生 2024-09-13 00:01:55

这是使用隐式构建器模式执行此操作的一种方法。与以前的解决方案相比,它的优点是它可以与任何 Traversable 一起使用,并返回类似的 Traversable。可悲的是,这是非常必要的。如果有人愿意,它可能会变成一个相当丑陋的折叠。

object RichTraversable {
  implicit def traversable2RichTraversable[A](t: Traversable[A]) = new RichTraversable[A](t)
}

class RichTraversable[A](t: Traversable[A]) { 
  def argMax[That, C](g: A => C)(implicit bf : scala.collection.generic.CanBuildFrom[Traversable[A], A, That], ord:Ordering[C]): That = {
    var minimum:C = null.asInstanceOf[C]
    val repr = t.repr
    val builder = bf(repr)
    for(a<-t){
      val test: C = g(a)
      if(test == minimum || minimum == null){
        builder += a
        minimum = test
      }else if (ord.gt(test, minimum)){
        builder.clear
        builder += a
        minimum = test
      }
    }
    builder.result
  }
}

Set(-2, -1, 0, 1, 2).argmax(x=>x*x) == Set(-2, 2)
List(-2, -1, 0, 1, 2).argmax(x=>x*x) == List(-2, 2)

Here's a way of doing so with the implicit builder pattern. It has the advantage over the previous solutions that it works with any Traversable, and returns a similar Traversable. Sadly, it's pretty imperative. If anyone wants to, it could probably be turned into a fairly ugly fold instead.

object RichTraversable {
  implicit def traversable2RichTraversable[A](t: Traversable[A]) = new RichTraversable[A](t)
}

class RichTraversable[A](t: Traversable[A]) { 
  def argMax[That, C](g: A => C)(implicit bf : scala.collection.generic.CanBuildFrom[Traversable[A], A, That], ord:Ordering[C]): That = {
    var minimum:C = null.asInstanceOf[C]
    val repr = t.repr
    val builder = bf(repr)
    for(a<-t){
      val test: C = g(a)
      if(test == minimum || minimum == null){
        builder += a
        minimum = test
      }else if (ord.gt(test, minimum)){
        builder.clear
        builder += a
        minimum = test
      }
    }
    builder.result
  }
}

Set(-2, -1, 0, 1, 2).argmax(x=>x*x) == Set(-2, 2)
List(-2, -1, 0, 1, 2).argmax(x=>x*x) == List(-2, 2)
像你 2024-09-13 00:01:55

这是一个基于 @Daniel 接受的答案的变体,也适用于 Sets。

def argMax[A, B: Ordering](input: GenIterable[A], f: A => B) : GenSet[A] = argMaxZip(input, f) map (_._1) toSet

def argMaxZip[A, B: Ordering](input: GenIterable[A], f: A => B): GenIterable[(A, B)] = {
  if (input.isEmpty) Nil
  else {
    val fPairs = input map (x => (x, f(x)))
    val maxF = fPairs.map(_._2).max
    fPairs filter (_._2 == maxF)
  }
}

当然,我们也可以做一种产生 (B, I​​terable[A]) 的变体。

Here's a variant loosely based on @Daniel's accepted answer that also works for Sets.

def argMax[A, B: Ordering](input: GenIterable[A], f: A => B) : GenSet[A] = argMaxZip(input, f) map (_._1) toSet

def argMaxZip[A, B: Ordering](input: GenIterable[A], f: A => B): GenIterable[(A, B)] = {
  if (input.isEmpty) Nil
  else {
    val fPairs = input map (x => (x, f(x)))
    val maxF = fPairs.map(_._2).max
    fPairs filter (_._2 == maxF)
  }
}

One could also do a variant that produces (B, Iterable[A]), of course.

め可乐爱微笑 2024-09-13 00:01:55

根据其他答案,您可以很容易地结合每个答案的优点(对 f() 的调用最少,等等)。在这里,我们对所有 Iterables 进行了隐式转换(因此它们可以透明地调用 .argmax()),并且如果出于某种原因首选独立方法。 ScalaTest 测试启动。

class Argmax[A](col: Iterable[A]) {
  def argmax[B](f: A => B)(implicit ord: Ordering[B]): Iterable[A] = {
    val mapped = col map f
    val max = mapped max ord
    (mapped zip col) filter (_._1 == max) map (_._2)
  }
}

object MathOps {
  implicit def addArgmax[A](col: Iterable[A]) = new Argmax(col)

  def argmax[A, B](col: Iterable[A])(f: A => B)(implicit ord: Ordering[B]) = {
    new Argmax(col) argmax f
  }
}

class MathUtilsTests extends FunSuite {
  import MathOps._

  test("Can argmax with unique") {
    assert((-10 to 0).argmax(_ * -1).toSet === Set(-10))
    // or alternate calling syntax
    assert(argmax(-10 to 0)(_ * -1).toSet === Set(-10))
  }

  test("Can argmax with multiple") {
    assert((-10 to 10).argmax(math.pow(_, 2)).toSet === Set(-10, 10))
  }
}

Based on other answers, you can pretty easily combine the strengths of each (minimal calls to f(), etc.). Here we have an implicit conversion for all Iterables (so they can just call .argmax() transparently), and a stand-alone method if for some reason that is preferred. ScalaTest tests to boot.

class Argmax[A](col: Iterable[A]) {
  def argmax[B](f: A => B)(implicit ord: Ordering[B]): Iterable[A] = {
    val mapped = col map f
    val max = mapped max ord
    (mapped zip col) filter (_._1 == max) map (_._2)
  }
}

object MathOps {
  implicit def addArgmax[A](col: Iterable[A]) = new Argmax(col)

  def argmax[A, B](col: Iterable[A])(f: A => B)(implicit ord: Ordering[B]) = {
    new Argmax(col) argmax f
  }
}

class MathUtilsTests extends FunSuite {
  import MathOps._

  test("Can argmax with unique") {
    assert((-10 to 0).argmax(_ * -1).toSet === Set(-10))
    // or alternate calling syntax
    assert(argmax(-10 to 0)(_ * -1).toSet === Set(-10))
  }

  test("Can argmax with multiple") {
    assert((-10 to 10).argmax(math.pow(_, 2)).toSet === Set(-10, 10))
  }
}
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文