Skip to content

Commit

Permalink
add XsJavaLocalDate as an alternative to XsDate if useJavaTime key se…
Browse files Browse the repository at this point in the history
…tting in sbt is true.
  • Loading branch information
khajavi committed Nov 1, 2020
1 parent a1c99e8 commit a9843e9
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 51 deletions.
14 changes: 13 additions & 1 deletion cli/src/main/resources/scalaxb.scala.template
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package scalaxb

import java.time.LocalDate
import javax.xml.datatype.{DatatypeFactory, XMLGregorianCalendar}
import scala.xml.{Node, NodeSeq, NamespaceBinding, Elem, UnprefixedAttribute, PrefixedAttribute}
import javax.xml.datatype.{XMLGregorianCalendar}
import javax.xml.namespace.QName
import javax.xml.bind.DatatypeConverter

Expand Down Expand Up @@ -209,6 +210,17 @@ trait XMLStandardTypes {
Helper.stringToXML(obj.toXMLFormat, namespace, elementLabel, scope)
}

implicit lazy val __JavaLocalDateXMLFormat: XMLFormat[java.time.LocalDate] = new XMLFormat[LocalDate] {
def localDate (d: XMLGregorianCalendar) : LocalDate = LocalDate.of(d.getYear, d.getMonth, d.getDay)
def gregorianCalendar (d: LocalDate): XMLGregorianCalendar = Helper.toCalendar(d.toString)

def reads(seq: scala.xml.NodeSeq, stack: List[ElemName]): Either[String, LocalDate] =
implicitly[XMLFormat[XMLGregorianCalendar]].reads(seq, stack).map(localDate)

override def writes(obj: LocalDate, namespace: Option[String], elementLabel: Option[String], scope: NamespaceBinding, typeAttribute: Boolean): NodeSeq =
implicitly[XMLFormat[XMLGregorianCalendar]].writes(gregorianCalendar(obj), namespace, elementLabel, scope, typeAttribute)
}

implicit lazy val __GregorianCalendarXMLWriter: CanWriteXML[java.util.GregorianCalendar] = new CanWriteXML[java.util.GregorianCalendar] {
def writes(obj: java.util.GregorianCalendar, namespace: Option[String], elementLabel: Option[String],
scope: scala.xml.NamespaceBinding, typeAttribute: Boolean): scala.xml.NodeSeq =
Expand Down
2 changes: 2 additions & 0 deletions cli/src/main/scala/scalaxb/compiler/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ case class Config(items: Map[String, ConfigEntry]) {
def symbolEncodingStrategy = get[SymbolEncoding.Strategy] getOrElse defaultSymbolEncodingStrategy
def enumNameMaxLength: Int = (get[EnumNameMaxLength] getOrElse defaultEnumNameMaxLength).value
def useLists: Boolean = values contains UseLists
def useJavaTime: Boolean = values contains UseJavaTime

private def get[A <: ConfigEntry: Manifest]: Option[A] =
items.get(implicitly[Manifest[A]].runtimeClass.getName).asInstanceOf[Option[A]]
Expand Down Expand Up @@ -158,6 +159,7 @@ object ConfigEntry {
case object CapitalizeWords extends ConfigEntry
case class EnumNameMaxLength(value: Int) extends ConfigEntry
case object UseLists extends ConfigEntry
case object UseJavaTime extends ConfigEntry

object SymbolEncoding {
sealed abstract class Strategy(val alias: String, val description: String) extends ConfigEntry with Product with Serializable {
Expand Down
46 changes: 25 additions & 21 deletions cli/src/main/scala/scalaxb/compiler/Module.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@

package scalaxb.compiler

import java.net.{URI}
import scala.xml.{Node, Elem, UnprefixedAttribute, NamespaceBinding}
import scala.xml.factory.{XMLLoader}
import java.net.URI

import scala.xml.{Elem, NamespaceBinding, Node, UnprefixedAttribute}
import scala.xml.factory.XMLLoader
import javax.xml.parsers.SAXParser
import java.io.{File, PrintWriter, Reader, BufferedReader}
import scala.collection.mutable
import java.io.{BufferedReader, File, PrintWriter, Reader}

import scala.collection.mutable.{ListBuffer, ListMap}
import ConfigEntry._
import scalaxb.compiler.xsd.ParserConfig

object Snippet {
def apply(snippets: Snippet*): Snippet =
Expand Down Expand Up @@ -117,7 +119,7 @@ trait Module {
def includeLocations: Seq[String]
def raw: RawSchema
def location: URI
def toSchema(context: Context): Schema
def toSchema(context: Context, config: ParserConfig): Schema
def swapTargetNamespace(outerNamespace: Option[String], n: Int): Importable
}

Expand Down Expand Up @@ -201,26 +203,28 @@ trait Module {

def processReaders[From, To](files: Seq[From], config: Config)
(implicit ev: CanBeRawSchema[From, RawSchema], evTo: CanBeWriter[To]): (CompileSource[From], List[To]) = {
val source = buildCompileSource(files)
val source = buildCompileSource(files, config)
(source, processCompileSource(source, config))
}

def buildCompileSource[From, To](files: Seq[From])
def buildCompileSource[From, To](files: Seq[From], config: Config)
(implicit ev: CanBeRawSchema[From, RawSchema]): CompileSource[From] = {

logger.debug("%s", files.toString())
val context = buildContext
val importables0 = ListMap[From, Importable](files map { f =>
f -> toImportable(ev.toURI(f), ev.toRawSchema(f))}: _*)
val importables = ListBuffer[(Importable, From)](files map { f => importables0(f) -> f }: _*)
val parserConfig = new ParserConfig
parserConfig.useJavaTime = config.useJavaTime
val schemas = ListMap[Importable, Schema](importables map { case (importable, file) =>
val s = parse(importable, context)
val s = parse(importable, context, parserConfig)
(importable, s) } toSeq: _*)

val additionalImportables = ListMap.empty[Importable, File]

// recursively add missing files
def addMissingFiles(): Unit = {
def addMissingFiles(parserConfig: ParserConfig): Unit = {
val current = (importables map {_._1}) ++ additionalImportables.keysIterator.toList
// check for all dependencies before proceeding.
val missings = (current flatMap { importable =>
Expand All @@ -239,12 +243,12 @@ trait Module {
added = true
val importable = toImportable(implicitly[CanBeRawSchema[File, RawSchema]].toURI(x),
implicitly[CanBeRawSchema[File, RawSchema]].toRawSchema(x))
val s = parse(importable, context)
val s = parse(importable, context, parserConfig)
schemas(importable) = s
(importable, x) })
if (added) addMissingFiles()
if (added) addMissingFiles(parserConfig)
}
def processUnnamedIncludes(): Unit = {
def processUnnamedIncludes(parserConfig: ParserConfig): Unit = {
logger.debug("processUnnamedIncludes")
val all = (importables.toList map {_._1}) ++ (additionalImportables.toList map {_._1})
val parents: ListBuffer[Importable] = ListBuffer(all filter { !_.includeLocations.isEmpty}: _*)
Expand All @@ -270,7 +274,7 @@ trait Module {
logger.debug("processUnnamedIncludes - setting %s's outer namespace to %s", x.location, tnsstr)
count += 1
val swap = x.swapTargetNamespace(tns, count)
schemas(swap) = parse(swap, context)
schemas(swap) = parse(swap, context, parserConfig)
additionalImportables(swap) = new File(swap.location.getPath)
used += x
}
Expand All @@ -292,8 +296,8 @@ trait Module {
}
}

addMissingFiles()
processUnnamedIncludes()
addMissingFiles(parserConfig)
processUnnamedIncludes(parserConfig)
CompileSource(context, schemas, importables, additionalImportables,
importables0(files.head).targetNamespace)
}
Expand Down Expand Up @@ -426,11 +430,11 @@ trait Module {

def nodeToRawSchema(node: Node): RawSchema

def parse(importable: Importable, context: Context): Schema
= importable.toSchema(context)
def parse(importable: Importable, context: Context, config: ParserConfig): Schema
= importable.toSchema(context, config)

def parse(location: URI, in: Reader): Schema
= parse(toImportable(location, readerToRawSchema(in)), buildContext)
def parse(location: URI, in: Reader, config: ParserConfig): Schema
= parse(toImportable(location, readerToRawSchema(in)), buildContext, config)

def printNodes(nodes: Seq[Node], out: PrintWriter): Unit = {
import scala.xml._
Expand Down Expand Up @@ -487,7 +491,7 @@ trait Module {
NamespaceBinding(null, outerNamespace getOrElse null, scope)
def fixSeq(ns: Seq[Node]): Seq[Node] =
for { node <- ns } yield node match {
case elem: Elem =>
case elem: Elem =>
elem.copy(scope = fixScope(elem.scope),
child = fixSeq(elem.child))
case other => other
Expand Down
6 changes: 3 additions & 3 deletions cli/src/main/scala/scalaxb/compiler/wsdl11/Driver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import java.net.URI

import scala.xml.Node
import scala.reflect.ClassTag
import scalaxb.compiler.xsd.{GenProtocol, SchemaDecl, SchemaLite, XsdContext}
import scalaxb.compiler.xsd.{GenProtocol, ParserConfig, SchemaDecl, SchemaLite, XsdContext}

import scala.util.matching.Regex

Expand Down Expand Up @@ -166,14 +166,14 @@ class Driver extends Module { driver =>
schemaLite.includes map { _.schemaLocation }
}

def toSchema(context: Context): WsdlPair = {
def toSchema(context: Context, config: ParserConfig): WsdlPair = {
wsdl foreach { wsdl =>
logger.debug(wsdl.toString)
context.definitions += wsdl
}

val xsd = xsdRawSchema map { x =>
val schema = SchemaDecl.fromXML(x, context.xsdcontext)
val schema = SchemaDecl.fromXML(x, context.xsdcontext, config)
logger.debug(schema.toString)
schema
}
Expand Down
8 changes: 4 additions & 4 deletions cli/src/main/scala/scalaxb/compiler/wsdl11/GenSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ trait {interfaceTypeName} {{
def buildPartArg(part: XPartType, selector: String): String =
(part.typeValue, part.element) match {
case (Some(typeValueQName), _) =>
val typeSymbol = toTypeSymbol(typeValueQName)
val typeSymbol = toTypeSymbol(typeValueQName, config.useJavaTime)
xsdgenerator.buildArg(xsdgenerator.buildTypeName(typeSymbol), selector, Single, None)
case (_, Some(elementQName)) =>
val elem = xsdgenerator.elements(splitTypeName(elementQName))
Expand Down Expand Up @@ -718,17 +718,17 @@ trait {interfaceTypeName} {{
def toParamCache(part: XPartType): ParamCache =
part.typeValue map { typeValue =>
val name = camelCase(part.name getOrElse "in")
ParamCache(name, toTypeSymbol(typeValue), Single, false, false)
ParamCache(name, toTypeSymbol(typeValue, config.useJavaTime), Single, false, false)
} getOrElse {
part.element map { element =>
val param = xsdgenerator.buildParam(xsdgenerator.elements(splitTypeName(element))) map {camelCase}
ParamCache(param.toParamName, param.typeSymbol, param.cardinality, param.nillable, false)
} getOrElse {sys.error("part does not have either type or element: " + part.toString)}
}

def toTypeSymbol(qname: javax.xml.namespace.QName): XsTypeSymbol = {
def toTypeSymbol(qname: javax.xml.namespace.QName, useJavaTime: Boolean): XsTypeSymbol = {
import scalaxb.compiler.xsd.{ReferenceTypeSymbol, TypeSymbolParser}
val symbol = TypeSymbolParser.fromQName(qname)
val symbol = TypeSymbolParser.fromQName(qname, useJavaTime)
symbol match {
case symbol: ReferenceTypeSymbol =>
val (namespace, typeName) = splitTypeName(qname)
Expand Down
17 changes: 9 additions & 8 deletions cli/src/main/scala/scalaxb/compiler/xsd/Decl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class ParserConfig {
var targetNamespace: Option[String] = None
var elementQualifiedDefault: Boolean = false
var attributeQualifiedDefault: Boolean = false
var useJavaTime: Boolean = false
val topElems = mutable.ListMap.empty[String, ElemDecl]
val elemList = mutable.ListBuffer.empty[ElemDecl]
val topTypes = mutable.ListMap.empty[String, TypeDecl]
Expand All @@ -78,16 +79,16 @@ object TypeSymbolParser {
val XML_URI = "http://www.w3.org/XML/1998/namespace"

def fromString(name: String, scope: NamespaceBinding, config: ParserConfig): XsTypeSymbol =
fromString(splitTypeName(name, scope, config.targetNamespace))
fromString(splitTypeName(name, scope, config.targetNamespace), config.useJavaTime)

def fromQName(qname: javax.xml.namespace.QName): XsTypeSymbol =
fromString((masked.scalaxb.Helper.nullOrEmpty(qname.getNamespaceURI), qname.getLocalPart))
def fromQName(qname: javax.xml.namespace.QName, useJavaTime: Boolean): XsTypeSymbol =
fromString((masked.scalaxb.Helper.nullOrEmpty(qname.getNamespaceURI), qname.getLocalPart), useJavaTime)

def fromString(pair: (Option[String], String)): XsTypeSymbol = {
def fromString(pair: (Option[String], String), useJavaTime: Boolean): XsTypeSymbol = {
val (namespace, localPart) = pair
namespace match {
case Some(XML_SCHEMA_URI) =>
if (XsTypeSymbol.toTypeSymbol.isDefinedAt(localPart)) XsTypeSymbol.toTypeSymbol(localPart)
if (XsTypeSymbol.toTypeSymbol(useJavaTime).isDefinedAt(localPart)) XsTypeSymbol.toTypeSymbol(useJavaTime)(localPart)
else ReferenceTypeSymbol(namespace, localPart)
case _ => ReferenceTypeSymbol(namespace, localPart)
}
Expand Down Expand Up @@ -147,8 +148,8 @@ case class SchemaDecl(targetNamespace: Option[String],

object SchemaDecl {
def fromXML(node: scala.xml.Node,
context: XsdContext,
config: ParserConfig = new ParserConfig) = {
context: XsdContext,
config: ParserConfig) = {
val schema = (node \\ "schema").headOption.getOrElse {
sys.error("xsd: schema element not found: " + node.toString) }
val targetNamespace = schema.attribute("targetNamespace").headOption map { _.text }
Expand Down Expand Up @@ -492,7 +493,7 @@ object SimpleTypeDecl {
def fromXML(node: scala.xml.Node, name: String, family: List[String], config: ParserConfig): SimpleTypeDecl = {
var content: ContentTypeDecl = null
for (child <- node.child) child match {
case <restriction>{ _* }</restriction> => content = SimpTypRestrictionDecl.fromXML(child, family, config)
case <restriction>{ _* }</restriction> => content = SimpTypRestrictionDecl.fromXML(child, family, config)
case <list>{ _* }</list> => content = SimpTypListDecl.fromXML(child, family, config)
case <union>{ _* }</union> => content = SimpTypUnionDecl.fromXML(child, config)
case _ =>
Expand Down
4 changes: 2 additions & 2 deletions cli/src/main/scala/scalaxb/compiler/xsd/Driver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ class Driver extends Module { driver =>
}
val includeLocations: Seq[String] = schemaLite.includes map { _.schemaLocation }

def toSchema(context: Context): Schema = {
val schema = SchemaDecl.fromXML(raw, context)
def toSchema(context: Context, config: ParserConfig): Schema = {
val schema = SchemaDecl.fromXML(raw, context, config)
logger.debug("toSchema: " + schema.toString())
schema
}
Expand Down
24 changes: 13 additions & 11 deletions cli/src/main/scala/scalaxb/compiler/xsd/XsTypeSymbol.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
/*
* Copyright (c) 2010 e.e d3si9n
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
Expand All @@ -19,15 +19,15 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

package scalaxb.compiler.xsd

import javax.xml.namespace.QName

trait XsTypeSymbol extends scala.xml.TypeSymbol {
val name: String
override def toString(): String = name

override def toString(): String = name
}

object XsAnyType extends XsTypeSymbol {
Expand Down Expand Up @@ -116,6 +116,7 @@ object XsDuration extends BuiltInSimpleTypeSymbol("javax.xml.datatype.Du
object XsDateTime extends BuiltInSimpleTypeSymbol("javax.xml.datatype.XMLGregorianCalendar") {}
object XsTime extends BuiltInSimpleTypeSymbol("javax.xml.datatype.XMLGregorianCalendar") {}
object XsDate extends BuiltInSimpleTypeSymbol("javax.xml.datatype.XMLGregorianCalendar") {}
object XsJavaLocalDate extends BuiltInSimpleTypeSymbol("java.time.LocalDate")
object XsGYearMonth extends BuiltInSimpleTypeSymbol("javax.xml.datatype.XMLGregorianCalendar") {}
object XsGYear extends BuiltInSimpleTypeSymbol("javax.xml.datatype.XMLGregorianCalendar") {}
object XsGMonthDay extends BuiltInSimpleTypeSymbol("javax.xml.datatype.XMLGregorianCalendar") {}
Expand Down Expand Up @@ -160,14 +161,15 @@ object XsUnsignedByte extends BuiltInSimpleTypeSymbol("Int") {}
object XsTypeSymbol {
type =>?[A, B] = PartialFunction[A, B]
val LOCAL_ELEMENT = "http://scalaxb.org/local-element"
val toTypeSymbol: String =>? XsTypeSymbol = {

def toTypeSymbol(useJavaTime: Boolean): String =>? XsTypeSymbol = {
case "anyType" => XsAnyType
case "anySimpleType" => XsAnySimpleType
case "duration" => XsDuration
case "dateTime" => XsDateTime
case "time" => XsTime
case "date" => XsDate
case "date" if useJavaTime => XsJavaLocalDate
case "date" if !useJavaTime => XsDate
case "gYearMonth" => XsGYearMonth
case "gYear" => XsGYear
case "gMonthDay" => XsGMonthDay
Expand Down Expand Up @@ -207,6 +209,6 @@ object XsTypeSymbol {
case "short" => XsShort
case "unsignedShort" => XsUnsignedShort
case "byte" => XsByte
case "unsignedByte" => XsUnsignedByte
}
case "unsignedByte" => XsUnsignedByte
}
}
1 change: 1 addition & 0 deletions sbt-scalaxb/src/main/scala/sbtscalaxb/ScalaxbKeys.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ trait ScalaxbKeys {
lazy val scalaxbSymbolEncodingStrategy = settingKey[SymbolEncodingStrategy.Value]("Specifies the strategy to encode non-identifier characters in generated class names")
lazy val scalaxbEnumNameMaxLength = settingKey[Int]("Truncates names of enum members longer than this value (default: 50)")
lazy val scalaxbUseLists = settingKey[Boolean]("Declare sequences with concrete type List instead of Seq")
lazy val scalaxbUseJavaTime = settingKey[Boolean]("Use Java Time (java.time.*) instead of XMLGregorianCalendar (javax.xml.datatype.*)")

object HttpClientType extends Enumeration {
val None, Dispatch, Gigahorse = Value
Expand Down
4 changes: 3 additions & 1 deletion sbt-scalaxb/src/main/scala/sbtscalaxb/ScalaxbPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ object ScalaxbPlugin extends sbt.AutoPlugin {
scalaxbSymbolEncodingStrategy := SymbolEncodingStrategy.Legacy151,
scalaxbEnumNameMaxLength := 50,
scalaxbUseLists := false,
scalaxbUseJavaTime := false,
scalaxbConfig :=
ScConfig(
Vector(PackageNames(scalaxbCombinedPackageNames.value)) ++
Expand Down Expand Up @@ -136,7 +137,8 @@ object ScalaxbPlugin extends sbt.AutoPlugin {
(if (scalaxbCapitalizeWords.value) Vector(CapitalizeWords) else Vector()) ++
Vector(SymbolEncoding.withName(scalaxbSymbolEncodingStrategy.value.toString)) ++
Vector(EnumNameMaxLength(scalaxbEnumNameMaxLength.value)) ++
(if (scalaxbUseLists.value) Vector(UseLists) else Vector())
(if (scalaxbUseLists.value) Vector(UseLists) else Vector()) ++
(if (scalaxbUseJavaTime.value) Vector(UseJavaTime) else Vector())
)
))
}

0 comments on commit a9843e9

Please sign in to comment.