Skip to content

Commit

Permalink
feat: [+] #105 improve raiseError with location (#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
eruizalo authored Oct 3, 2022
1 parent 338b337 commit 695f3dc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
package doric
package syntax

import doric.sem.Location
import org.apache.spark.sql.{functions => f}

private[syntax] trait StringColumns31 {

/**
* Throws an exception with the provided error message.
*
* @throws java.lang.RuntimeException with the error message
* @group String Type
* @see [[org.apache.spark.sql.functions.raise_error]]
*/
def raiseError(str: String)(implicit l: Location): NullColumn =
str.lit.raiseError

implicit class StringOperationsSyntax31(s: DoricColumn[String]) {

/**
Expand All @@ -20,6 +31,8 @@ private[syntax] trait StringColumns31 {
* @group String Type
* @see [[org.apache.spark.sql.functions.raise_error]]
*/
def raiseError: NullColumn = s.elem.map(f.raise_error).toDC
def raiseError(implicit l: Location): NullColumn =
ds"""$s
located at . ${l.getLocation.lit}""".elem.map(f.raise_error).toDC
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package doric
package syntax

import org.scalatest.EitherValues
import org.scalatest.{Assertion, EitherValues}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.sql.{functions => f}
import org.apache.spark.sql.types.NullType

Expand All @@ -15,7 +14,18 @@ class StringColumns31Spec
describe("raiseError doric function") {
import spark.implicits._

val df = List("this is an error").toDF("errorMsg")
lazy val errorMsg = "this is an error"
lazy val df = List(errorMsg).toDF("errorMsg")

def validateExceptions(
doricExc: RuntimeException,
sparkExc: RuntimeException
): Assertion = {
// doricExc.getMessage should fullyMatch regex
// s"""${sparkExc.getMessage}
// located at . (${this.getClass.getSimpleName}.scala:33)"""
doricExc.getMessage should startWith(sparkExc.getMessage)
}

it("should work as spark raise_error function") {
import java.lang.{RuntimeException => exception}
Expand All @@ -30,7 +40,23 @@ class StringColumns31Spec
df.select(f.raise_error(f.col("errorMsg"))).collect()
}

doricErr.getMessage shouldBe sparkErr.getMessage
validateExceptions(doricErr, sparkErr)
}

it("should be available for strings") {
import java.lang.{RuntimeException => exception}

val doricErr = intercept[exception] {
val res = df.select(raiseError(errorMsg))

res.schema.head.dataType shouldBe NullType
res.collect()
}
val sparkErr = intercept[exception] {
df.select(f.raise_error(f.col("errorMsg"))).collect()
}

validateExceptions(doricErr, sparkErr)
}
}

Expand Down

0 comments on commit 695f3dc

Please sign in to comment.