Skip to content

Commit

Permalink
feat: [~] move round function by spark needs
Browse files Browse the repository at this point in the history
  • Loading branch information
eruizalo committed Aug 12, 2023
1 parent b43dec0 commit 1f144e7
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 30 deletions.
29 changes: 1 addition & 28 deletions core/src/main/scala/doric/syntax/NumericColumns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ package syntax
import cats.implicits._
import doric.DoricColumn.sparkFunction
import doric.types.{CollectionType, NumericType}
import org.apache.spark.sql.catalyst.expressions.{BRound, Expression, FormatNumber, FromUnixTime, Rand, Randn, Round, RoundBase, UnaryMinus}
import org.apache.spark.sql.catalyst.expressions.{BRound, FormatNumber, FromUnixTime, Rand, Randn, Round, RoundBase, UnaryMinus}
import org.apache.spark.sql.{Column, functions => f}

import scala.math.BigDecimal.RoundingMode.RoundingMode

protected trait NumericColumns {

/**
Expand Down Expand Up @@ -599,31 +597,6 @@ protected trait NumericColumns {
.mapN((c, s) => new Column(Round(c.expr, s.expr)))
.toDC

/**
* DORIC EXCLUSIVE! Round the value to `scale` decimal places with given round `mode`
* if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
*
* @todo decimal type
* @group Numeric Type
*/
def round(scale: IntegerColumn, mode: RoundingMode): DoricColumn[T] = {
case class DoricRound(
child: Expression,
scale: Expression,
mode: RoundingMode
) extends RoundBase(child, scale, mode, s"ROUND_$mode") {
override protected def withNewChildrenInternal(
newLeft: Expression,
newRight: Expression
): DoricRound =
copy(child = newLeft, scale = newRight)
}

(column.elem, scale.elem)
.mapN((c, s) => new Column(DoricRound(c.expr, s.expr, mode)))
.toDC
}

/**
* Returns col1 if it is not NaN, or col2 if col1 is NaN.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ package syntax
import cats.implicits._
import org.apache.spark.sql.Column
import org.apache.spark.sql.{functions => f}
import org.apache.spark.sql.catalyst.expressions.{ShiftLeft, ShiftRight, ShiftRightUnsigned}
import org.apache.spark.sql.catalyst.expressions.{Expression, RoundBase, ShiftLeft, ShiftRight, ShiftRightUnsigned}

import scala.math.BigDecimal.RoundingMode.RoundingMode

protected trait NumericColumns2_31 {

Expand Down Expand Up @@ -57,4 +59,31 @@ protected trait NumericColumns2_31 {
def bitwiseNot: DoricColumn[T] = column.elem.map(f.bitwiseNOT).toDC
}

/**
* NUM WITH DECIMALS OPERATIONS
*/
implicit class NumWithDecimalsOperationsSyntax2_31[T: NumWithDecimalsType](
column: DoricColumn[T]
) {

/**
* DORIC EXCLUSIVE! Round the value to `scale` decimal places with given round `mode`
* if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
*
* @todo decimal type
* @group Numeric Type
*/
def round(scale: IntegerColumn, mode: RoundingMode): DoricColumn[T] = {
case class DoricRound(
child: Expression,
scale: Expression,
mode: RoundingMode
) extends RoundBase(child, scale, mode, s"ROUND_$mode")

(column.elem, scale.elem)
.mapN((c, s) => new Column(DoricRound(c.expr, s.expr, mode)))
.toDC
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ package syntax
import cats.implicits._

import org.apache.spark.sql.{Column, functions => f}
import org.apache.spark.sql.catalyst.expressions.{ShiftLeft, ShiftRight, ShiftRightUnsigned}
import org.apache.spark.sql.catalyst.expressions.{Expression, RoundBase, ShiftLeft, ShiftRight, ShiftRightUnsigned}

import scala.math.BigDecimal.RoundingMode.RoundingMode

protected trait NumericColumns32 {

Expand Down Expand Up @@ -57,4 +59,37 @@ protected trait NumericColumns32 {
def bitwiseNot: DoricColumn[T] = column.elem.map(f.bitwise_not).toDC
}

/**
* NUM WITH DECIMALS OPERATIONS
*/
implicit class NumWithDecimalsOperationsSyntax32[T: NumWithDecimalsType](
column: DoricColumn[T]
) {

/**
* DORIC EXCLUSIVE! Round the value to `scale` decimal places with given round `mode`
* if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
*
* @todo decimal type
* @group Numeric Type
*/
def round(scale: IntegerColumn, mode: RoundingMode): DoricColumn[T] = {
case class DoricRound(
child: Expression,
scale: Expression,
mode: RoundingMode
) extends RoundBase(child, scale, mode, s"ROUND_$mode") {
override protected def withNewChildrenInternal(
newLeft: Expression,
newRight: Expression
): DoricRound =
copy(child = newLeft, scale = newRight)
}

(column.elem, scale.elem)
.mapN((c, s) => new Column(DoricRound(c.expr, s.expr, mode)))
.toDC
}
}

}

0 comments on commit 1f144e7

Please sign in to comment.