Skip to content

Commit

Permalink
thoughts check + observe wip
Browse files Browse the repository at this point in the history
  • Loading branch information
fwbrasil committed Dec 30, 2023
1 parent faabde5 commit ac94fb1
Show file tree
Hide file tree
Showing 11 changed files with 161 additions and 71 deletions.
5 changes: 4 additions & 1 deletion kyo-core/shared/src/main/scala/kyo/stats/attributes.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package kyo.stats

import scala.annotation.implicitNotFound
import kyo.stats.Attributes.AsAttribute

case class Attributes(get: List[Attributes.Attribute]) extends AnyVal {
def add(a: Attributes): Attributes =
Attributes(get ++ a.get)
def add[T](name: String, value: T)(implicit a: AsAttribute[T]): Attributes =
add(Attributes.add(name, value))
}

object Attributes {
val empty: Attributes = Attributes(Nil)

def of[T](name: String, value: T)(implicit a: AsAttribute[T]) =
def add[T](name: String, value: T)(implicit a: AsAttribute[T]) =
Attributes(a.f(name, value) :: Nil)

def all(l: List[Attributes]): Attributes =
Expand Down
28 changes: 14 additions & 14 deletions kyo-core/shared/src/test/scala/kyoTest/stats/AttributesTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,79 +13,79 @@ class AttributesTest extends KyoTest {
}

"one" in {
val attr = Attributes.of("test", true)
val attr = Attributes.add("test", true)
assert(attr.get.size == 1)
assert(attr.get.head.isInstanceOf[Attribute.BooleanAttribute])
}

"add" in {
val attr1 = Attributes.of("test1", true)
val attr2 = Attributes.of("test2", 123)
val attr1 = Attributes.add("test1", true)
val attr2 = Attributes.add("test2", 123)
val combined = attr1.add(attr2)
assert(combined.get.size == 2)
}

"all" in {
val attrList = List(Attributes.of("test1", true), Attributes.of("test2", 123.45))
val attrList = List(Attributes.add("test1", true), Attributes.add("test2", 123.45))
val combined = Attributes.all(attrList)
assert(combined.get.size == attrList.size)
}

"primitives" - {

"boolean" in {
val booleanAttr = Attributes.of("bool", true)
val booleanAttr = Attributes.add("bool", true)
assert(booleanAttr.get.head.isInstanceOf[Attribute.BooleanAttribute])
}

"int" in {
val booleanAttr = Attributes.of("int", 1)
val booleanAttr = Attributes.add("int", 1)
assert(booleanAttr.get.head.isInstanceOf[Attribute.LongAttribute])
}

"double" in {
val doubleAttr = Attributes.of("double", 123.45)
val doubleAttr = Attributes.add("double", 123.45)
assert(doubleAttr.get.head.isInstanceOf[Attribute.DoubleAttribute])
}

"long" in {
val longAttr = Attributes.of("long", 123L)
val longAttr = Attributes.add("long", 123L)
assert(longAttr.get.head.isInstanceOf[Attribute.LongAttribute])
}

"string" in {
val stringAttr = Attributes.of("string", "value")
val stringAttr = Attributes.add("string", "value")
assert(stringAttr.get.head.isInstanceOf[Attribute.StringAttribute])
}
}

"lists" - {
"boolean list" in {
val boolListAttr = Attributes.of("boolList", List(true, false, true))
val boolListAttr = Attributes.add("boolList", List(true, false, true))
assert(boolListAttr.get.head.isInstanceOf[Attribute.BooleanListAttribute])
assert(boolListAttr.get.head.asInstanceOf[Attribute.BooleanListAttribute].value.size == 3)
}

"integer list" in {
val intListAttr = Attributes.of("intList", List(1, 2, 3))
val intListAttr = Attributes.add("intList", List(1, 2, 3))
assert(intListAttr.get.head.isInstanceOf[Attribute.LongListAttribute])
assert(intListAttr.get.head.asInstanceOf[Attribute.LongListAttribute].value.size == 3)
}

"double list" in {
val doubleListAttr = Attributes.of("doubleList", List(1.1, 2.2, 3.3))
val doubleListAttr = Attributes.add("doubleList", List(1.1, 2.2, 3.3))
assert(doubleListAttr.get.head.isInstanceOf[Attribute.DoubleListAttribute])
assert(doubleListAttr.get.head.asInstanceOf[Attribute.DoubleListAttribute].value.size == 3)
}

"long list" in {
val longListAttr = Attributes.of("longList", List(100L, 200L, 300L))
val longListAttr = Attributes.add("longList", List(100L, 200L, 300L))
assert(longListAttr.get.head.isInstanceOf[Attribute.LongListAttribute])
assert(longListAttr.get.head.asInstanceOf[Attribute.LongListAttribute].value.size == 3)
}

"string list" in {
val stringListAttr = Attributes.of("stringList", List("a", "b", "c"))
val stringListAttr = Attributes.add("stringList", List("a", "b", "c"))
assert(stringListAttr.get.head.isInstanceOf[Attribute.StringListAttribute])
assert(stringListAttr.get.head.asInstanceOf[Attribute.StringListAttribute].value.size == 3)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class CounterTest extends KyoTest {
for {
_ <- Counter.noop.inc
_ <- Counter.noop.add(1)
_ <- Counter.noop.add(1, Attributes.of("test", 1))
_ <- Counter.noop.add(1, Attributes.add("test", 1))
} yield succeed
}

Expand All @@ -20,7 +20,7 @@ class CounterTest extends KyoTest {
for {
_ <- counter.inc
_ <- counter.add(1)
_ <- counter.add(1, Attributes.of("test", 1))
_ <- counter.add(1, Attributes.add("test", 1))
} yield assert(unsafe.curr == 3)
}

Expand All @@ -39,7 +39,7 @@ class CounterTest extends KyoTest {
for {
_ <- counter.inc
_ <- counter.add(1)
_ <- counter.add(1, Attributes.of("test", 1))
_ <- counter.add(1, Attributes.add("test", 1))
} yield {
assert(unsafe1.curr == 3 && unsafe2.curr == 3)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class HistogramTest extends KyoTest {
"noop" in run {
for {
_ <- Histogram.noop.observe(1.0)
_ <- Histogram.noop.observe(1.0, Attributes.of("test", 1))
_ <- Histogram.noop.observe(1.0, Attributes.add("test", 1))
} yield succeed
}

Expand All @@ -18,7 +18,7 @@ class HistogramTest extends KyoTest {
val histogram = Histogram(unsafe)
for {
_ <- histogram.observe(1.0)
_ <- histogram.observe(1.0, Attributes.of("test", 1))
_ <- histogram.observe(1.0, Attributes.add("test", 1))
} yield assert(unsafe.observations == 2)
}

Expand All @@ -36,7 +36,7 @@ class HistogramTest extends KyoTest {
val histogram = Histogram.all(List(Histogram(unsafe1), Histogram(unsafe2)))
for {
_ <- histogram.observe(1.0)
_ <- histogram.observe(1.0, Attributes.of("test", 1))
_ <- histogram.observe(1.0, Attributes.add("test", 1))
} yield {
assert(unsafe1.observations == 2 && unsafe2.observations == 2)
}
Expand Down
10 changes: 5 additions & 5 deletions kyo-llm/shared/src/main/scala/kyo/llm/agents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ package object agents {
val output: Json[Out]
)

def info: Info
val info: Info

def thoughts: List[Thought.Info] = Nil
val thoughts: List[Thought.Info] = Nil

private val local = Locals.init(Option.empty[AI])

Expand All @@ -57,7 +57,7 @@ package object agents {
case None => AIs.init
}

private[kyo] def request: Schema = {
private[kyo] val schema: Schema = {
def schema[T](name: String, l: List[Thought.Info]): ZSchema[T] = {
val fields = l.map { t =>
import zio.schema.Schema._
Expand All @@ -68,7 +68,7 @@ package object agents {
Validation.succeed,
identity,
(_, _) => ListMap.empty
)
)
}
ZSchema.record(TypeId.fromTypeName(name), FieldSet(fields: _*)).asInstanceOf[ZSchema[T]]
}
Expand Down Expand Up @@ -153,7 +153,7 @@ package object agents {
"Call this agent with the result."
)

override def thoughts: List[Thought.Info] =
override val thoughts: List[Thought.Info] =
_thoughts

def run(input: T) =
Expand Down
2 changes: 1 addition & 1 deletion kyo-llm/shared/src/main/scala/kyo/llm/completions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ object completions {
ToolDef(FunctionDef(
p.info.description,
p.info.name,
p.request
p.schema
))
).toList)
Request(
Expand Down
85 changes: 57 additions & 28 deletions kyo-llm/shared/src/main/scala/kyo/llm/thoughts/Check.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,45 @@ import kyo.ios.IOs

object Check {

private val stats = Thought.stats.scope("checks")
private val stats = Thought.stats.scope("check")
private val success = stats.initCounter("success")
private val failure = stats.initCounter("failure")

case class CheckFailed(path: List[String], invariant: String, analysis: String)
case class CheckFailed(ai: AI, thought: Thought, invariant: String, analysis: String)
extends RuntimeException

private def observe(path: List[String], result: Boolean) = {
private def observe(parent: Thought, field: String, result: Boolean) = {
val c = if (result) success else failure
c.attributes(Attributes.of("thought", path.last)).inc
c.attributes(
Attributes
.add("thought", parent.name)
.add("field", field)
).inc
}

private def warn(ai: AI, path: List[String], invariant: String): Unit < AIs =
private def warn(
ai: AI,
parent: Thought,
field: String,
invariant: String,
analysis: String = "Plase reason about the failure and fix any mistakes.",
repair: Option[Repair] = None
): Unit < AIs =
ai.systemMessage(
p"""
Thought Invariant Failure
=========================
Thought: ${parent.name}
Field: $field
Description: $invariant
Path: ${path.map(v => s"`$v`").mkString(".")}
Plase analyze and fix any mistakes.
Analysis: $analysis
${repair.map(pprint(_).plainText).map("Repair: " + _).getOrElse("")}
"""
)

case class Info(result: Boolean) extends Thought {
override def eval(path: List[String], ai: AI) =
observe(path, result)
override def eval(parent: Thought, field: String, ai: AI) =
observe(parent, field, result)
}

object Info {
Expand All @@ -46,9 +59,9 @@ object Check {
`Invariant check description`: Invarant,
`Invariant holds`: Boolean
) extends Thought {
override def eval(path: List[String], ai: AI) =
observe(path, `Invariant holds`).andThen {
warn(ai, path, `Invariant check description`)
override def eval(parent: Thought, field: String, ai: AI) =
observe(parent, field, `Invariant holds`).andThen {
warn(ai, parent, field, `Invariant check description`)
}
}

Expand All @@ -57,10 +70,21 @@ object Check {
`Invariant check analysis`: String,
`Invariant holds`: Boolean
) extends Thought {
override def eval(path: List[String], ai: AI) =
observe(path, `Invariant holds`).andThen {
warn(ai, path, `Invariant check description`).andThen {
IOs.fail(CheckFailed(path, `Invariant check description`, `Invariant check analysis`))
override def eval(parent: Thought, field: String, ai: AI) =
observe(parent, field, `Invariant holds`).andThen {
warn(
ai,
parent,
field,
`Invariant check description`,
`Invariant check analysis`
).andThen {
IOs.fail(CheckFailed(
ai,
parent,
`Invariant check description`,
`Invariant check analysis`
))
}
}
}
Expand All @@ -70,21 +94,26 @@ object Check {
`Invariant check analysis`: String,
`Invariant holds`: Boolean
) extends Thought {
override def eval(path: List[String], ai: AI) =
observe(path, `Invariant holds`).andThen {
override def eval(parent: Thought, field: String, ai: AI) =
observe(parent, field, `Invariant holds`).andThen {
AIs.ephemeral {
warn(ai, path, `Invariant check description`).andThen {
ai.gen[Repair]("Provide a repair for the failed thought invariant.")
warn(
ai,
parent,
field,
`Invariant check description`,
`Invariant check analysis`
).andThen {
ai.gen[Repair]("Provide a repair for the last failed thought invariant.")
}
}.map { repair =>
ai.systemMessage(
p"""
Thought Invariant Repair
========================
Description: ${`Invariant check description`}
Path: ${path.map(v => s"`$v`").mkString(".")}
Inferred Repair: ${pprint(repair)}
"""
warn(
ai,
parent,
field,
`Invariant check description`,
`Invariant check analysis`,
Some(repair)
)
}
}
Expand Down
Loading

0 comments on commit ac94fb1

Please sign in to comment.