diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 8e7a75e437..da5225800c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -22,8 +22,7 @@ import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedRoute, BlindedRouteDe import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, UInt64, randomBytes, randomKey} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, UInt64, randomBytes, randomKey} -import org.json4s.JsonAST._ +import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods import org.scalatest.funsuite.AnyFunSuite import scodec.bits._ @@ -453,22 +452,22 @@ class SphinxSpec extends AnyFunSuite { assert(decryptionError == expected) } + case class TestVector(encodedFailureMessage: String, hops: Seq[TestHop]) + case class TestHop(sharedSecret: String, encryptedMessage: String) + implicit val formats: DefaultFormats.type = DefaultFormats + test("attributable error test vector") { val src = Source.fromFile(new File(getClass.getResource(s"/attributable_error.json").getFile)) - try { - val testVector = JsonMethods.parse(src.mkString).asInstanceOf[JObject].values - val encodedFailureMessage = ByteVector.fromValidHex(testVector("encodedFailureMessage").asInstanceOf[String]) - val expected = FailureMessageCodecs.failureOnionPayload(0).decode(encodedFailureMessage.bits).require.value - val hops = testVector("hops").asInstanceOf[List[Map[String, String]]] - val sharedSecrets = hops.map(hop => ByteVector32(ByteVector.fromValidHex(hop("sharedSecret")))).reverse - val encryptedMessage = hops.map(hop => ByteVector.fromValidHex(hop("encryptedMessage"))).last - val nodeIds = (1 to 5).map(_ => randomKey().publicKey) - val Right(DecryptedFailurePacket(originNode, failureMessage)) = AttributableErrorPacket.decrypt(encryptedMessage, sharedSecrets.zip(nodeIds)) - assert(originNode == nodeIds.last) - assert(failureMessage == expected) - } finally { - src.close() - } + val testVector = JsonMethods.parse(src.mkString).extract[TestVector] + src.close() + val encodedFailureMessage = ByteVector.fromValidHex(testVector.encodedFailureMessage) + val expected = FailureMessageCodecs.failureOnionPayload(0).decode(encodedFailureMessage.bits).require.value + val sharedSecrets = testVector.hops.map(hop => ByteVector32(ByteVector.fromValidHex(hop.sharedSecret))).reverse + val encryptedMessage = testVector.hops.map(hop => ByteVector.fromValidHex(hop.encryptedMessage)).last + val nodeIds = (1 to testVector.hops.length).map(_ => randomKey().publicKey) + val Right(DecryptedFailurePacket(originNode, failureMessage)) = AttributableErrorPacket.decrypt(encryptedMessage, sharedSecrets.zip(nodeIds)) + assert(originNode == nodeIds.last) + assert(failureMessage == expected) } test("create blinded route (reference test vector)") {