diff --git a/build.sbt b/build.sbt index 54924d9d..a42cf377 100644 --- a/build.sbt +++ b/build.sbt @@ -195,7 +195,21 @@ lazy val core = project ProblemFilters.exclude[IncompatibleResultTypeProblem]( "sangria.schema.WithInputTypeRendering.deprecationTracker"), ProblemFilters.exclude[ReversedMissingMethodProblem]( - "sangria.schema.WithInputTypeRendering.deprecationTracker") + "sangria.schema.WithInputTypeRendering.deprecationTracker"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.validation.RuleBasedQueryValidator.validateInputDocument"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.validation.RuleBasedQueryValidator.validateInputDocument"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.schema.SchemaChange#AbstractAstDirectiveAdded.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.schema.SchemaChange#InputObjectTypeAstDirectiveAdded.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.schema.SchemaChange#InputObjectTypeAstDirectiveAdded.this"), + ProblemFilters.exclude[MissingTypesProblem]( + "sangria.schema.SchemaChange$InputObjectTypeAstDirectiveAdded$"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.schema.SchemaChange#InputObjectTypeAstDirectiveAdded.apply") ), Test / testOptions += Tests.Argument(TestFrameworks.ScalaTest, "-oF"), libraryDependencies ++= Seq( diff --git a/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala b/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala index f0169695..592cc8c5 100644 --- a/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala +++ b/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala @@ -98,7 +98,7 @@ class OverlappingFieldsCanBeMergedBenchmark { bh.consume(doValidate(validator, deepAbstractConcrete)) private def doValidate(validator: QueryValidator, document: Document): Vector[Violation] = { - val result = validator.validateQuery(schema, document, None) + val result = validator.validateQuery(schema, document, Map.empty, None) require(result.isEmpty) result } diff --git a/modules/core/src/main/scala/sangria/execution/Executor.scala b/modules/core/src/main/scala/sangria/execution/Executor.scala index aa166324..c6573105 100644 --- a/modules/core/src/main/scala/sangria/execution/Executor.scala +++ b/modules/core/src/main/scala/sangria/execution/Executor.scala @@ -29,107 +29,123 @@ case class Executor[Ctx, Root]( operationName: Option[String] = None, variables: Input = emptyMapVars )(implicit um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] = { - val (violations, validationTiming) = - TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit)) + val scalarMiddleware = Middleware.composeFromScalarMiddleware(middleware, userContext) + val valueCollector = new ValueCollector[Ctx, Input]( + schema, + variables, + queryAst.sourceMapper, + deprecationTracker, + userContext, + exceptionHandler, + scalarMiddleware, + false)(um) - if (violations.nonEmpty) - Future.failed(ValidationError(violations, exceptionHandler)) - else { - val scalarMiddleware = Middleware.composeFromScalarMiddleware(middleware, userContext) - val valueCollector = new ValueCollector[Ctx, Input]( - schema, - variables, - queryAst.sourceMapper, - deprecationTracker, - userContext, - exceptionHandler, + val operationCtx = for { + operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) + unmarshalledVariables <- valueCollector.getVariableValues( + operation.variables, scalarMiddleware, - false)(um) + errorsLimit + ) + } yield (operation, unmarshalledVariables) - val executionResult = for { - operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) - unmarshalledVariables <- valueCollector.getVariableValues( - operation.variables, - scalarMiddleware, - errorsLimit - ) - fieldCollector = new FieldCollector[Ctx, Root]( - schema, - queryAst, - unmarshalledVariables, - queryAst.sourceMapper, - valueCollector, - exceptionHandler) - tpe <- Executor.getOperationRootType( - schema, - exceptionHandler, - operation, - queryAst.sourceMapper) - fields <- fieldCollector.collectFields(ExecutionPath.empty, tpe, Vector(operation)) - } yield { - val preparedFields = fields.fields.flatMap { - case CollectedField(_, astField, Success(_)) => - val allFields = - tpe.getField(schema, astField.name).asInstanceOf[Vector[Field[Ctx, Root]]] - val field = allFields.head - val args = valueCollector.getFieldArgumentValues( - ExecutionPath.empty.add(astField, tpe), - Some(astField), - field.arguments, - astField.arguments, - unmarshalledVariables) + operationCtx match { + case Failure(error) => + // return validation errors without variables first if variables is what failed + val violations = queryValidator.validateQuery(schema, queryAst, Map.empty, errorsLimit) - args.toOption.map(PreparedField(field, _)) - case _ => None - } + if (violations.nonEmpty) + Future.failed(ValidationError(violations, exceptionHandler)) + else + Future.failed(error) + case Success((operation, unmarshalledVariables)) => + val (violations, validationTiming) = + TimeMeasurement.measure( + queryValidator.validateQuery(schema, queryAst, unmarshalledVariables, errorsLimit)) - QueryReducerExecutor - .reduceQuery( - schema, - queryReducers, - exceptionHandler, - fieldCollector, - valueCollector, - unmarshalledVariables, - tpe, - fields, - userContext) - .map { case (newCtx, timing) => - new PreparedQuery[Ctx, Root, Input]( - queryAst, + if (violations.nonEmpty) + Future.failed(ValidationError(violations, exceptionHandler)) + else { + val executionResult = for { + tpe <- Executor.getOperationRootType( + schema, + exceptionHandler, operation, - tpe, - newCtx, - root, - preparedFields, - (c: Ctx, r: Root, m: ResultMarshaller, scheme: ExecutionScheme) => - executeOperation( + queryAst.sourceMapper) + fieldCollector = new FieldCollector[Ctx, Root]( + schema, + queryAst, + unmarshalledVariables, + queryAst.sourceMapper, + valueCollector, + exceptionHandler) + fields <- fieldCollector.collectFields(ExecutionPath.empty, tpe, Vector(operation)) + } yield { + val preparedFields = fields.fields.flatMap { + case CollectedField(_, astField, Success(_)) => + val allFields = + tpe.getField(schema, astField.name).asInstanceOf[Vector[Field[Ctx, Root]]] + val field = allFields.head + val args = valueCollector.getFieldArgumentValues( + ExecutionPath.empty.add(astField, tpe), + Some(astField), + field.arguments, + astField.arguments, + unmarshalledVariables) + + args.toOption.map(PreparedField(field, _)) + case _ => None + } + + QueryReducerExecutor + .reduceQuery( + schema, + queryReducers, + exceptionHandler, + fieldCollector, + valueCollector, + unmarshalledVariables, + tpe, + fields, + userContext) + .map { case (newCtx, timing) => + new PreparedQuery[Ctx, Root, Input]( queryAst, - operationName, - variables, - um, operation, - queryAst.sourceMapper, - valueCollector, - fieldCollector, - m, - unmarshalledVariables, tpe, - fields, - c, - r, - scheme, - validationTiming, - timing + newCtx, + root, + preparedFields, + (c: Ctx, r: Root, m: ResultMarshaller, scheme: ExecutionScheme) => + executeOperation( + queryAst, + operationName, + variables, + um, + operation, + queryAst.sourceMapper, + valueCollector, + fieldCollector, + m, + unmarshalledVariables, + tpe, + fields, + c, + r, + scheme, + validationTiming, + timing + ) ) - ) + } } - } - executionResult match { - case Success(future) => future - case Failure(error) => Future.failed(error) - } + executionResult match { + case Success(future) => future + case Failure(error) => Future.failed(error) + } + } + } } @@ -143,82 +159,98 @@ case class Executor[Ctx, Root]( marshaller: ResultMarshaller, um: InputUnmarshaller[Input], scheme: ExecutionScheme): scheme.Result[Ctx, marshaller.Node] = { - val (violations, validationTiming) = - TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit)) - if (violations.nonEmpty) - scheme.failed(ValidationError(violations, exceptionHandler)) - else { - val scalarMiddleware = Middleware.composeFromScalarMiddleware(middleware, userContext) - val valueCollector = new ValueCollector[Ctx, Input]( - schema, - variables, - queryAst.sourceMapper, - deprecationTracker, - userContext, - exceptionHandler, + val scalarMiddleware = Middleware.composeFromScalarMiddleware(middleware, userContext) + val valueCollector = new ValueCollector[Ctx, Input]( + schema, + variables, + queryAst.sourceMapper, + deprecationTracker, + userContext, + exceptionHandler, + scalarMiddleware, + false)(um) + + val operationCtx = for { + operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) + unmarshalledVariables <- valueCollector.getVariableValues( + operation.variables, scalarMiddleware, - false)(um) + errorsLimit + ) + } yield (operation, unmarshalledVariables) - val executionResult = for { - operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) - unmarshalledVariables <- valueCollector.getVariableValues( - operation.variables, - scalarMiddleware, - errorsLimit - ) - fieldCollector = new FieldCollector[Ctx, Root]( - schema, - queryAst, - unmarshalledVariables, - queryAst.sourceMapper, - valueCollector, - exceptionHandler) - tpe <- Executor.getOperationRootType( - schema, - exceptionHandler, - operation, - queryAst.sourceMapper) - fields <- fieldCollector.collectFields(ExecutionPath.empty, tpe, Vector(operation)) - } yield { - val reduced = QueryReducerExecutor.reduceQuery( - schema, - queryReducers, - exceptionHandler, - fieldCollector, - valueCollector, - unmarshalledVariables, - tpe, - fields, - userContext) - scheme.flatMapFuture(reduced) { case (newCtx, timing) => - executeOperation( - queryAst, - operationName, - variables, - um, - operation, - queryAst.sourceMapper, - valueCollector, - fieldCollector, - marshaller, - unmarshalledVariables, - tpe, - fields, - newCtx, - root, - scheme, - validationTiming, - timing - ) - } - } + operationCtx match { + case Failure(error) => + // return validation errors without variables first if variables is what failed + val violations = queryValidator.validateQuery(schema, queryAst, Map.empty, errorsLimit) + if (violations.nonEmpty) + scheme.failed(ValidationError(violations, exceptionHandler)) + else + scheme.failed(error) + case Success((operation, unmarshalledVariables)) => + val (violations, validationTiming) = + TimeMeasurement.measure( + queryValidator.validateQuery(schema, queryAst, unmarshalledVariables, errorsLimit)) - executionResult match { - case Success(result) => result - case Failure(error) => scheme.failed(error) - } + if (violations.nonEmpty) + scheme.failed(ValidationError(violations, exceptionHandler)) + else { + val executionResult = for { + tpe <- Executor.getOperationRootType( + schema, + exceptionHandler, + operation, + queryAst.sourceMapper) + fieldCollector = new FieldCollector[Ctx, Root]( + schema, + queryAst, + unmarshalledVariables, + queryAst.sourceMapper, + valueCollector, + exceptionHandler) + fields <- fieldCollector.collectFields(ExecutionPath.empty, tpe, Vector(operation)) + } yield { + val reduced = QueryReducerExecutor.reduceQuery( + schema, + queryReducers, + exceptionHandler, + fieldCollector, + valueCollector, + unmarshalledVariables, + tpe, + fields, + userContext) + scheme.flatMapFuture(reduced) { case (newCtx, timing) => + executeOperation( + queryAst, + operationName, + variables, + um, + operation, + queryAst.sourceMapper, + valueCollector, + fieldCollector, + marshaller, + unmarshalledVariables, + tpe, + fields, + newCtx, + root, + scheme, + validationTiming, + timing + ) + } + } + + executionResult match { + case Success(result) => result + case Failure(error) => scheme.failed(error) + } + } } + } private def executeOperation[Input]( diff --git a/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala b/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala index a70ba15e..ab2453ca 100644 --- a/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala +++ b/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala @@ -2,6 +2,7 @@ package sangria.execution import sangria.ast.{AstVisitor, InputDocument, VariableDefinition} import sangria.ast +import sangria.execution import sangria.marshalling.{FromInput, InputUnmarshaller} import sangria.renderer.SchemaRenderer import sangria.schema._ @@ -28,16 +29,30 @@ case class InputDocumentMaterializer[Vars]( None, false)(iu) - val violations = QueryValidator.default.validateInputDocument(schema, document, inputType) + val variableDefinitions = inferVariableDefinitions(document, inputType) - if (violations.nonEmpty) - Failure(InputDocumentMaterializationError(violations, ExceptionHandler.empty)) - else { - val variableDefinitions = inferVariableDefinitions(document, inputType) + collector.getVariableValues(variableDefinitions, None) match { + case Failure(e) => + // return validation errors without variables first if variables is what failed + val violations = + QueryValidator.default.validateInputDocument( + schema, + document, + inputType, + Map.empty[String, execution.VariableValue] + ) - collector.getVariableValues(variableDefinitions, None) match { - case Failure(e) => Failure(e) - case Success(vars) => + if (violations.nonEmpty) + Failure(InputDocumentMaterializationError(violations, ExceptionHandler.empty)) + else + Failure(e) + case Success(vars) => + val violations = + QueryValidator.default.validateInputDocument(schema, document, inputType, vars) + + if (violations.nonEmpty) + Failure(InputDocumentMaterializationError(violations, ExceptionHandler.empty)) + else { try Success(document.values.flatMap { value => collector.coercionHelper.coerceInputValue( @@ -56,7 +71,7 @@ case class InputDocumentMaterializer[Vars]( catch { case NonFatal(e) => Failure(e) } - } + } } } diff --git a/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala b/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala index 258ec369..d6d3dc82 100644 --- a/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala +++ b/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala @@ -22,7 +22,7 @@ object QueryReducerExecutor { middleware: List[Middleware[Ctx]] = Nil, errorsLimit: Option[Int] = None )(implicit executionContext: ExecutionContext): Future[(Ctx, TimeMeasurement)] = { - val violations = queryValidator.validateQuery(schema, queryAst, errorsLimit) + val violations = queryValidator.validateQuery(schema, queryAst, Map.empty, errorsLimit) if (violations.nonEmpty) Future.failed(ValidationError(violations, exceptionHandler)) diff --git a/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala b/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala index 795fe07e..7fba78be 100644 --- a/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala +++ b/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala @@ -101,7 +101,10 @@ object BatchExecutor { inferVariableDefinitions, exceptionHandler)) .flatMap { case res @ (updatedDocument, _) => - val violations = queryValidator.validateQuery(schema, updatedDocument, errorsLimit) + // we're not going to pass variables here, as we call validateQuery again on + // executeIndividual which has the unmarshalled variables at that point + val violations = + queryValidator.validateQuery(schema, updatedDocument, Map.empty, errorsLimit) if (violations.nonEmpty) Failure(ValidationError(violations, exceptionHandler)) else Success(res) diff --git a/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala b/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala index 7d711e68..c0549829 100644 --- a/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala +++ b/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala @@ -63,7 +63,8 @@ class ResolverBasedAstSchemaBuilder[Ctx](val resolvers: Seq[AstSchemaResolver[Ct schema: ast.Document, validator: QueryValidator = ResolverBasedAstSchemaBuilder.validator, errorsLimit: Option[Int] = None): Vector[Violation] = - allowKnownDynamicDirectives(validator.validateQuery(validationSchema, schema, errorsLimit)) + allowKnownDynamicDirectives( + validator.validateQuery(validationSchema, schema, Map.empty, errorsLimit)) def validateSchemaWithException( schema: ast.Document, diff --git a/modules/core/src/main/scala/sangria/schema/SchemaComparator.scala b/modules/core/src/main/scala/sangria/schema/SchemaComparator.scala index d356ee68..ac6b8b84 100644 --- a/modules/core/src/main/scala/sangria/schema/SchemaComparator.scala +++ b/modules/core/src/main/scala/sangria/schema/SchemaComparator.scala @@ -280,7 +280,11 @@ object SchemaComparator { val directiveChanges = findInAstDirs( oldType.astDirectives, newType.astDirectives, - added = SchemaChange.InputObjectTypeAstDirectiveAdded(newType, _), + added = d => + SchemaChange.InputObjectTypeAstDirectiveAdded( + newType, + d, + breaking = d.name == OneOfDirective.name), removed = SchemaChange.InputObjectTypeAstDirectiveRemoved(newType, _) ) @@ -900,9 +904,9 @@ object SchemaChange { abstract class AbstractAstDirectiveAdded( val description: String, - val location: DirectiveLocation.Value) + val location: DirectiveLocation.Value, + val breakingChange: Boolean) extends AstDirectiveAdded { - val breakingChange = false val dangerousChange = false } @@ -920,7 +924,9 @@ object SchemaChange { directive: ast.Directive) extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact(directive)}` added on a field `${tpe.name}.${field.name}`", - DirectiveLocation.FieldDefinition) + DirectiveLocation.FieldDefinition, + breakingChange = false + ) case class FieldAstDirectiveRemoved( tpe: ObjectLikeType[_, _], @@ -937,7 +943,9 @@ object SchemaChange { extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact( directive)}` added on an enum value `${tpe.name}.${value.name}`", - DirectiveLocation.EnumValue) + DirectiveLocation.EnumValue, + breakingChange = false + ) case class EnumValueAstDirectiveRemoved( tpe: EnumType[_], @@ -955,7 +963,9 @@ object SchemaChange { extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact( directive)}` added on an input field `${tpe.name}.${field.name}`", - DirectiveLocation.InputFieldDefinition) + DirectiveLocation.InputFieldDefinition, + breakingChange = false + ) case class InputFieldAstDirectiveRemoved( tpe: InputObjectType[_], @@ -974,7 +984,8 @@ object SchemaChange { extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact( directive)}` added on a directive argument `${dir.name}.${argument.name}`", - DirectiveLocation.ArgumentDefinition + DirectiveLocation.ArgumentDefinition, + breakingChange = false ) case class DirectiveArgumentAstDirectiveRemoved( @@ -995,7 +1006,8 @@ object SchemaChange { extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact( directive)}` added on a field argument `${tpe.name}.${field.name}[${argument.name}]`", - DirectiveLocation.ArgumentDefinition + DirectiveLocation.ArgumentDefinition, + breakingChange = false ) case class FieldArgumentAstDirectiveRemoved( @@ -1012,7 +1024,8 @@ object SchemaChange { case class ObjectTypeAstDirectiveAdded(tpe: ObjectType[_, _], directive: ast.Directive) extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact(directive)}` added on an object type `${tpe.name}`", - DirectiveLocation.Object) + DirectiveLocation.Object, + breakingChange = false) case class ObjectTypeAstDirectiveRemoved(tpe: ObjectType[_, _], directive: ast.Directive) extends AbstractAstDirectiveRemoved( @@ -1022,7 +1035,8 @@ object SchemaChange { case class InterfaceTypeAstDirectiveAdded(tpe: InterfaceType[_, _], directive: ast.Directive) extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact(directive)}` added on an interface type `${tpe.name}`", - DirectiveLocation.Interface) + DirectiveLocation.Interface, + breakingChange = false) case class InterfaceTypeAstDirectiveRemoved(tpe: InterfaceType[_, _], directive: ast.Directive) extends AbstractAstDirectiveRemoved( @@ -1032,7 +1046,8 @@ object SchemaChange { case class UnionTypeAstDirectiveAdded(tpe: UnionType[_], directive: ast.Directive) extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact(directive)}` added on a union type `${tpe.name}`", - DirectiveLocation.Union) + DirectiveLocation.Union, + breakingChange = false) case class UnionTypeAstDirectiveRemoved(tpe: UnionType[_], directive: ast.Directive) extends AbstractAstDirectiveRemoved( @@ -1042,7 +1057,8 @@ object SchemaChange { case class EnumTypeAstDirectiveAdded(tpe: EnumType[_], directive: ast.Directive) extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact(directive)}` added on an enum type `${tpe.name}`", - DirectiveLocation.Enum) + DirectiveLocation.Enum, + breakingChange = false) case class EnumTypeAstDirectiveRemoved(tpe: EnumType[_], directive: ast.Directive) extends AbstractAstDirectiveRemoved( @@ -1052,18 +1068,14 @@ object SchemaChange { case class ScalarTypeAstDirectiveAdded(tpe: ScalarType[_], directive: ast.Directive) extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact(directive)}` added on a scalar type `${tpe.name}`", - DirectiveLocation.Scalar) + DirectiveLocation.Scalar, + breakingChange = false) case class ScalarTypeAstDirectiveRemoved(tpe: ScalarType[_], directive: ast.Directive) extends AbstractAstDirectiveRemoved( s"Directive `${QueryRenderer.renderCompact(directive)}` removed from a scalar type `${tpe.name}`", DirectiveLocation.Scalar) - case class InputObjectTypeAstDirectiveAdded(tpe: InputObjectType[_], directive: ast.Directive) - extends AbstractAstDirectiveAdded( - s"Directive `${QueryRenderer.renderCompact(directive)}` added on an input type `${tpe.name}`", - DirectiveLocation.InputObject) - case class InputObjectTypeAstDirectiveRemoved(tpe: InputObjectType[_], directive: ast.Directive) extends AbstractAstDirectiveRemoved( s"Directive `${QueryRenderer.renderCompact(directive)}` removed from an input type `${tpe.name}`", @@ -1072,7 +1084,8 @@ object SchemaChange { case class SchemaAstDirectiveAdded(schema: Schema[_, _], directive: ast.Directive) extends AbstractAstDirectiveAdded( s"Directive `${QueryRenderer.renderCompact(directive)}` added on a schema", - DirectiveLocation.Schema) + DirectiveLocation.Schema, + breakingChange = false) case class SchemaAstDirectiveRemoved(schema: Schema[_, _], directive: ast.Directive) extends AbstractAstDirectiveRemoved( @@ -1081,6 +1094,15 @@ object SchemaChange { // May be a breaking change + case class InputObjectTypeAstDirectiveAdded( + tpe: InputObjectType[_], + directive: ast.Directive, + breaking: Boolean) + extends AbstractAstDirectiveAdded( + s"Directive `${QueryRenderer.renderCompact(directive)}` added on an input type `${tpe.name}`", + DirectiveLocation.InputObject, + breakingChange = breaking) + case class InputFieldAdded(tpe: InputObjectType[_], field: InputField[_], breaking: Boolean) extends AbstractChange( s"Input field `${field.name}` was added to `${tpe.name}` type", diff --git a/modules/core/src/main/scala/sangria/schema/SchemaValidationRule.scala b/modules/core/src/main/scala/sangria/schema/SchemaValidationRule.scala index 3cbb6cd7..f2020b42 100644 --- a/modules/core/src/main/scala/sangria/schema/SchemaValidationRule.scala +++ b/modules/core/src/main/scala/sangria/schema/SchemaValidationRule.scala @@ -33,7 +33,9 @@ object SchemaValidationRule { ContainerMembersValidator, ValidNamesValidator, IntrospectionNamesValidator, - InputObjectTypeRecursionValidator) + InputObjectTypeRecursionValidator, + OneOfInputObjectValidator + ) val default: List[SchemaValidationRule] = List( DefaultValuesValidationRule, @@ -609,6 +611,32 @@ object EnumValueReservedNameValidator extends SchemaElementValidator { else Vector.empty } +object OneOfInputObjectValidator extends SchemaElementValidator { + override def validateInputObjectType( + schema: Schema[_, _], + tpe: InputObjectType[_] + ): Vector[Violation] = if (tpe.astDirectives.exists(_.name == OneOfDirective.name)) + tpe.fields.iterator.flatMap { field => + val defaultValueError = + field.defaultValue.map(_ => OneOfDefaultValueField(field.name, tpe.name, None, List.empty)) + + val nonOptionalError = if (field.fieldType.isOptional) { + None + } else { + Some( + OneOfMandatoryField( + field.name, + tpe.name, + None, + List.empty + ) + ) + } + Iterator(defaultValueError, nonOptionalError).flatten + }.toVector + else Vector.empty +} + object InputObjectTypeRecursionValidator extends SchemaElementValidator { override def validateInputObjectType( schema: Schema[_, _], diff --git a/modules/core/src/main/scala/sangria/schema/package.scala b/modules/core/src/main/scala/sangria/schema/package.scala index c53edd74..e540a4ec 100644 --- a/modules/core/src/main/scala/sangria/schema/package.scala +++ b/modules/core/src/main/scala/sangria/schema/package.scala @@ -267,8 +267,18 @@ package object schema { shouldInclude = ctx => !ctx.arg(IfArg) ) + val OneOfDirective: Directive = Directive( + "oneOf", + description = + Some("Indicates exactly one field must be supplied and this field must not be `null`."), + arguments = List.empty, + locations = Set( + DirectiveLocation.InputObject + ) + ) + val BuiltinDirectives: List[Directive] = - IncludeDirective :: SkipDirective :: DeprecatedDirective :: Nil + IncludeDirective :: SkipDirective :: DeprecatedDirective :: OneOfDirective :: Nil val BuiltinDirectivesByName: Map[String, Directive] = BuiltinDirectives.groupBy(_.name).map { case (k, v) => (k, v.head) } diff --git a/modules/core/src/main/scala/sangria/validation/QueryValidator.scala b/modules/core/src/main/scala/sangria/validation/QueryValidator.scala index b3ff0ca6..7e776e5b 100644 --- a/modules/core/src/main/scala/sangria/validation/QueryValidator.scala +++ b/modules/core/src/main/scala/sangria/validation/QueryValidator.scala @@ -3,6 +3,7 @@ package sangria.validation import sangria.ast import sangria.ast.AstVisitorCommand._ import sangria.ast.{AstVisitor, AstVisitorCommand, SourceMapper} +import sangria.execution import sangria.renderer.SchemaRenderer import sangria.schema._ import sangria.validation.rules._ @@ -14,6 +15,7 @@ trait QueryValidator { def validateQuery( schema: Schema[_, _], queryAst: ast.Document, + variableValues: Map[String, execution.VariableValue], errorsLimit: Option[Int]): Vector[Violation] } @@ -45,7 +47,8 @@ object QueryValidator { new VariablesAreInputTypes, new VariablesInAllowedPosition, new InputDocumentNonConflictingVariableInference, - new SingleFieldSubscriptions + new SingleFieldSubscriptions, + new ExactlyOneOfFieldGiven ) def ruleBased(rules: List[ValidationRule]): RuleBasedQueryValidator = @@ -55,6 +58,7 @@ object QueryValidator { def validateQuery( schema: Schema[_, _], queryAst: ast.Document, + variableValues: Map[String, execution.VariableValue], errorsLimit: Option[Int]): Vector[Violation] = Vector.empty } @@ -65,12 +69,15 @@ class RuleBasedQueryValidator(rules: List[ValidationRule]) extends QueryValidato def validateQuery( schema: Schema[_, _], queryAst: ast.Document, - errorsLimit: Option[Int]): Vector[Violation] = { + variables: Map[String, execution.VariableValue], + errorsLimit: Option[Int] + ): Vector[Violation] = { val ctx = new ValidationContext( schema, queryAst, queryAst.sourceMapper, new TypeInfo(schema), + variables, errorsLimit) validateUsingRules(queryAst, ctx, rules.map(_.visitor(ctx)), topLevel = true) @@ -81,9 +88,11 @@ class RuleBasedQueryValidator(rules: List[ValidationRule]) extends QueryValidato def validateInputDocument( schema: Schema[_, _], doc: ast.InputDocument, - inputTypeName: String): Vector[Violation] = + inputTypeName: String, + variables: Map[String, execution.VariableValue] + ): Vector[Violation] = schema.getInputType(ast.NamedType(inputTypeName)) match { - case Some(it) => validateInputDocument(schema, doc, it) + case Some(it) => validateInputDocument(schema, doc, it, variables) case None => throw new IllegalStateException( s"Can't find input type '$inputTypeName' in the schema. Known input types are: ${schema.inputTypes.keys.toVector.sorted @@ -93,10 +102,18 @@ class RuleBasedQueryValidator(rules: List[ValidationRule]) extends QueryValidato def validateInputDocument( schema: Schema[_, _], doc: ast.InputDocument, - inputType: InputType[_]): Vector[Violation] = { + inputType: InputType[_], + variables: Map[String, execution.VariableValue] + ): Vector[Violation] = { val typeInfo = new TypeInfo(schema, Some(inputType)) - val ctx = ValidationContext(schema, ast.Document.emptyStub, doc.sourceMapper, typeInfo) + val ctx = ValidationContext( + schema, + ast.Document.emptyStub, + doc.sourceMapper, + typeInfo, + variables + ) validateUsingRules(doc, ctx, rules.map(_.visitor(ctx)), topLevel = true) @@ -163,6 +180,7 @@ class ValidationContext( val doc: ast.Document, val sourceMapper: Option[SourceMapper], val typeInfo: TypeInfo, + val variables: Map[String, execution.VariableValue], errorsLimit: Option[Int]) { // Using mutable data-structures and mutability to minimize validation footprint @@ -193,9 +211,10 @@ object ValidationContext { schema: Schema[_, _], doc: ast.Document, sourceMapper: Option[SourceMapper], - typeInfo: TypeInfo + typeInfo: TypeInfo, + variables: Map[String, execution.VariableValue] ): ValidationContext = - new ValidationContext(schema, doc, sourceMapper, typeInfo, None) + new ValidationContext(schema, doc, sourceMapper, typeInfo, variables, None) @deprecated( "The validations are now implemented as a part of `ValuesOfCorrectType` validation.", diff --git a/modules/core/src/main/scala/sangria/validation/Violation.scala b/modules/core/src/main/scala/sangria/validation/Violation.scala index 988a5aeb..5d2b79c4 100644 --- a/modules/core/src/main/scala/sangria/validation/Violation.scala +++ b/modules/core/src/main/scala/sangria/validation/Violation.scala @@ -1105,6 +1105,33 @@ case class NoQueryTypeViolation(sourceMapper: Option[SourceMapper], locations: L "Must provide schema definition with query type or a type named Query." } +case class NotExactlyOneOfField( + typeName: String, + sourceMapper: Option[SourceMapper], + locations: List[AstLocation] +) extends AstNodeViolation { + lazy val simpleErrorMessage = s"Exactly one key must be specified for oneOf type '${typeName}'." +} + +case class OneOfMandatoryField( + fieldName: String, + typeName: String, + sourceMapper: Option[SourceMapper], + locations: List[AstLocation] +) extends AstNodeViolation { + lazy val simpleErrorMessage = s"oneOf input field '${typeName}.${fieldName}' must be nullable." +} + +case class OneOfDefaultValueField( + fieldName: String, + typeName: String, + sourceMapper: Option[SourceMapper], + locations: List[AstLocation] +) extends AstNodeViolation { + lazy val simpleErrorMessage = + s"oneOf input field '${typeName}.${fieldName}' cannot have a default value." +} + case class NonUniqueTypeDefinitionViolation( typeName: String, sourceMapper: Option[SourceMapper], diff --git a/modules/core/src/main/scala/sangria/validation/rules/ExactlyOneOfFieldGiven.scala b/modules/core/src/main/scala/sangria/validation/rules/ExactlyOneOfFieldGiven.scala new file mode 100644 index 00000000..9fb90ed7 --- /dev/null +++ b/modules/core/src/main/scala/sangria/validation/rules/ExactlyOneOfFieldGiven.scala @@ -0,0 +1,100 @@ +package sangria.validation.rules + +import sangria.ast +import sangria.execution +import sangria.execution.Trinary.{Defined, NullWithDefault} +import sangria.schema +import sangria.ast.AstVisitorCommand +import sangria.validation._ +import sangria.marshalling.CoercedScalaResultMarshaller + +/** For oneOf input objects, exactly one field should be non-null. */ +class ExactlyOneOfFieldGiven extends ValidationRule { + private val marshaller = CoercedScalaResultMarshaller.default + private val oneOfDirectiveName = schema.OneOfDirective.name + + private def hasOneOfDirective(inputObject: schema.InputObjectType[_]) = + inputObject.astDirectives.exists(_.name == oneOfDirectiveName) + + private def getResolvedVariableValue( + name: String, + inputType: schema.InputType[_], + variableValues: Map[String, execution.VariableValue] + ): Option[Any] = { + val variableValue = variableValues.get(name) + + variableValue.map(_.resolve(marshaller, marshaller, inputType)) match { + case Some(Right(Defined(resolved))) => Some(resolved) + case Some(Right(NullWithDefault(resolved))) => Some(resolved) + case _ => None + } + } + + private def visitNode( + ctx: ValidationContext, + inputType: Option[schema.InputType[_]], + node: Either[ast.ObjectValue, ast.VariableValue] + ): Either[Vector[Violation], AstVisitorCommand.Value] = + inputType.fold(AstVisitorCommand.RightContinue) { inputType => + inputType.namedInputType match { + case namedInputType: schema.InputObjectType[_] if hasOneOfDirective(namedInputType) => + val (allFields, nonNullFields) = node match { + case Left(ast.ObjectValue(fields, _, _)) => + val nonNullFields = fields.filter { field => + field.value match { + case ast.NullValue(_, _) => false + case ast.VariableValue(name, _, _) => + val fieldInputType = namedInputType.fieldsByName + .get(field.name) + .map(_.fieldType) + + fieldInputType.forall { fieldInputType => + getResolvedVariableValue(name, fieldInputType, ctx.variables).isDefined + } + case _ => true + } + } + (fields, nonNullFields) + + case Right(ast.VariableValue(name, _, _)) => + val variableValue = getResolvedVariableValue(name, namedInputType, ctx.variables) + + try + variableValue match { + case Some(resolved) => + val variableObj = resolved.asInstanceOf[Map[String, Any]] + val allFields = variableObj.filter { case (key, _) => + namedInputType.fieldsByName.contains(key) + } + val nonNullFields = allFields.filter { case (_, v) => v != None } + (allFields, nonNullFields) + case _ => (Vector.empty, Vector.empty) + } + catch { + // could get this from asInstanceOf failing for unexpected variable type. + // other validation will cover this problem. + case _: Throwable => (Vector.empty, Vector.empty) + } + } + + (allFields.size, nonNullFields.size) match { + case (1, 1) => AstVisitorCommand.RightContinue + case _ => + val pos = node.fold(_.location, _.location) + Left( + Vector( + NotExactlyOneOfField(namedInputType.name, ctx.sourceMapper, pos.toList) + ) + ) + } + case _ => AstVisitorCommand.RightContinue + } + } + + override def visitor(ctx: ValidationContext): AstValidatingVisitor = new AstValidatingVisitor { + override val onEnter: ValidationVisit = { + case node: ast.ObjectValue => visitNode(ctx, ctx.typeInfo.inputType, Left(node)) + case node: ast.VariableValue => visitNode(ctx, ctx.typeInfo.inputType, Right(node)) + } + } +} diff --git a/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala b/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala index 30d7fc17..cc3ad33d 100644 --- a/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala +++ b/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala @@ -2,6 +2,7 @@ package sangria.execution import sangria.macros._ import sangria.ast +import sangria.execution import sangria.marshalling.ScalaInput.scalaInput import sangria.marshalling.sprayJson._ import sangria.parser.QueryParser @@ -111,7 +112,11 @@ class InputDocumentMaterializerSpec extends AnyWordSpec with Matchers with Strin } """ - val errors = QueryValidator.default.validateInputDocument(schema, inp, "Config") + val errors = QueryValidator.default.validateInputDocument( + schema, + inp, + "Config", + Map.empty[String, execution.VariableValue]) assertViolations( errors, @@ -164,7 +169,11 @@ class InputDocumentMaterializerSpec extends AnyWordSpec with Matchers with Strin } """ - val errors = QueryValidator.default.validateInputDocument(schema, inp, "Config") + val errors = QueryValidator.default.validateInputDocument( + schema, + inp, + "Config", + Map.empty[String, execution.VariableValue]) assertViolations( errors, diff --git a/modules/core/src/test/scala/sangria/introspection/IntrospectionSpec.scala b/modules/core/src/test/scala/sangria/introspection/IntrospectionSpec.scala index 7e06e910..4ebb8793 100644 --- a/modules/core/src/test/scala/sangria/introspection/IntrospectionSpec.scala +++ b/modules/core/src/test/scala/sangria/introspection/IntrospectionSpec.scala @@ -790,6 +790,13 @@ class IntrospectionSpec extends AnyWordSpec with Matchers with FutureResultSuppo "defaultValue" -> "\"No longer supported\"" )), "isRepeatable" -> false + ), + Map( + "name" -> "oneOf", + "description" -> "Indicates exactly one field must be supplied and this field must not be `null`.", + "locations" -> Vector("INPUT_OBJECT"), + "args" -> Vector.empty, + "isRepeatable" -> false ) ), "description" -> null diff --git a/modules/core/src/test/scala/sangria/schema/AstSchemaMaterializerSpec.scala b/modules/core/src/test/scala/sangria/schema/AstSchemaMaterializerSpec.scala index e09997dd..7efc21ff 100644 --- a/modules/core/src/test/scala/sangria/schema/AstSchemaMaterializerSpec.scala +++ b/modules/core/src/test/scala/sangria/schema/AstSchemaMaterializerSpec.scala @@ -81,11 +81,12 @@ class AstSchemaMaterializerSpec val schema = Schema.buildFromAst(ast) - schema.directives should have size 3 + schema.directives should have size 4 (schema.directivesByName("skip") should be).theSameInstanceAs(SkipDirective) (schema.directivesByName("include") should be).theSameInstanceAs(IncludeDirective) (schema.directivesByName("deprecated") should be).theSameInstanceAs(DeprecatedDirective) + (schema.directivesByName("oneOf") should be).theSameInstanceAs(OneOfDirective) } "Overriding directives excludes specified" in { @@ -98,6 +99,7 @@ class AstSchemaMaterializerSpec directive @skip on FIELD directive @include on FIELD directive @deprecated on FIELD_DEFINITION + directive @oneOf on FIELD_DEFINITION type Hello { str: String @@ -106,12 +108,13 @@ class AstSchemaMaterializerSpec val schema = Schema.buildFromAst(ast) - schema.directives should have size 3 + schema.directives should have size 4 // We don't allow to override the built-in directives, since it's too dangerous (schema.directivesByName("skip") should be).theSameInstanceAs(SkipDirective) (schema.directivesByName("include") should be).theSameInstanceAs(IncludeDirective) (schema.directivesByName("deprecated") should be).theSameInstanceAs(DeprecatedDirective) + (schema.directivesByName("oneOf") should be).theSameInstanceAs(OneOfDirective) } "Adding directives maintains built-in one" in { @@ -130,11 +133,12 @@ class AstSchemaMaterializerSpec val schema = Schema.buildFromAst(ast) - schema.directives should have size 4 + schema.directives should have size 5 (schema.directivesByName("skip") should be).theSameInstanceAs(SkipDirective) (schema.directivesByName("include") should be).theSameInstanceAs(IncludeDirective) (schema.directivesByName("deprecated") should be).theSameInstanceAs(DeprecatedDirective) + (schema.directivesByName("oneOf") should be).theSameInstanceAs(OneOfDirective) } "Type modifiers" in { @@ -577,6 +581,44 @@ class AstSchemaMaterializerSpec error.getMessage should include("Must provide only one mutation type in schema.") } + "Does not allow mandatory fields in oneOf input objects" in { + val ast = graphql""" + type Query { + query(input: OneOfInput!): String + } + + input OneOfInput @oneOf { + foo: String! + bar: Int + } + """ + + val error = intercept[SchemaValidationException](Schema.buildFromAst(ast)) + + error.getMessage should include("oneOf input field 'OneOfInput.foo' must be nullable.") + } + + "Does not allow mandatory fields in oneOf input object extensions" in { + val ast = graphql""" + type Query { + query(input: OneOfInput!): String + } + + input OneOfInput @oneOf { + foo: String + bar: Int + } + + extend input OneOfInput @oneOf { + more: Boolean! + } + """ + + val error = intercept[SchemaValidationException](Schema.buildFromAst(ast)) + + error.getMessage should include("oneOf input field 'OneOfInput.more' must be nullable.") + } + "Allows only a single subscription type" in { val ast = graphql""" diff --git a/modules/core/src/test/scala/sangria/schema/SchemaComparatorSpec.scala b/modules/core/src/test/scala/sangria/schema/SchemaComparatorSpec.scala index 09c7a92c..626e2e7e 100644 --- a/modules/core/src/test/scala/sangria/schema/SchemaComparatorSpec.scala +++ b/modules/core/src/test/scala/sangria/schema/SchemaComparatorSpec.scala @@ -668,6 +668,60 @@ class SchemaComparatorSpec extends AnyWordSpec with Matchers { nonBreakingChange[ScalarTypeAstDirectiveRemoved]( "Directive `@bar(ids:[1,2])` removed from a scalar type `Foo5`") ) + + "detect removal of @oneOf" in checkChangesWithoutQueryType( + gql""" + input UserBy @oneOf { + id: ID + email: String + username: String + registrationNumber: Int + } + type Query { + user(by: UserBy!): String + } + """, + gql""" + input UserBy { + id: ID + email: String + username: String + registrationNumber: Int + } + type Query { + user(by: UserBy!): String + } + """, + nonBreakingChange[InputObjectTypeAstDirectiveRemoved]( + "Directive `@oneOf` removed from an input type `UserBy`") + ) + + "detect add of @oneOf" in checkChangesWithoutQueryType( + gql""" + input UserBy { + id: ID + email: String + username: String + registrationNumber: Int + } + type Query { + user(by: UserBy!): String + } + """, + gql""" + input UserBy @oneOf { + id: ID + email: String + username: String + registrationNumber: Int + } + type Query { + user(by: UserBy!): String + } + """, + breakingChange[InputObjectTypeAstDirectiveAdded]( + "Directive `@oneOf` added on an input type `UserBy`") + ) } private[this] def breakingChange[T: ClassTag](description: String) = diff --git a/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala b/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala index 203db9ce..4cc424a3 100644 --- a/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala +++ b/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala @@ -30,7 +30,8 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should be(Symbol("empty")) + QueryValidator.default.validateQuery(StarWarsSchema, query, Map.empty, None) should be( + Symbol("empty")) } "Notes that non-existent fields are invalid" in { @@ -42,7 +43,11 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 + QueryValidator.default.validateQuery( + StarWarsSchema, + query, + Map.empty, + None) should have size 1 } "Requires fields on objects" in { @@ -52,7 +57,11 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 + QueryValidator.default.validateQuery( + StarWarsSchema, + query, + Map.empty, + None) should have size 1 } "Disallows fields on scalars" in { @@ -66,7 +75,11 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 + QueryValidator.default.validateQuery( + StarWarsSchema, + query, + Map.empty, + None) should have size 1 } "Disallows object fields on interfaces" in { @@ -79,7 +92,11 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 + QueryValidator.default.validateQuery( + StarWarsSchema, + query, + Map.empty, + None) should have size 1 } "Allows object fields in fragments" in { @@ -96,7 +113,8 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should be(Symbol("empty")) + QueryValidator.default.validateQuery(StarWarsSchema, query, Map.empty, None) should be( + Symbol("empty")) } "Allows object fields in inline fragments" in { @@ -111,7 +129,8 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should be(Symbol("empty")) + QueryValidator.default.validateQuery(StarWarsSchema, query, Map.empty, None) should be( + Symbol("empty")) } } } diff --git a/modules/core/src/test/scala/sangria/util/CatsSupport.scala b/modules/core/src/test/scala/sangria/util/CatsSupport.scala index 60d56f92..573894af 100644 --- a/modules/core/src/test/scala/sangria/util/CatsSupport.scala +++ b/modules/core/src/test/scala/sangria/util/CatsSupport.scala @@ -221,7 +221,7 @@ object CatsScenarioExecutor extends FutureResultSupport { case Validate(rules) => ValidationResult( new RuleBasedQueryValidator(rules.toList) - .validateQuery(`given`.schema, QueryParser.parse(`given`.query).get, None)) + .validateQuery(`given`.schema, QueryParser.parse(`given`.query).get, Map.empty, None)) case Execute(validate, value, vars, op) => val validator = if (validate) QueryValidator.default else QueryValidator.empty diff --git a/modules/core/src/test/scala/sangria/util/ValidationSupport.scala b/modules/core/src/test/scala/sangria/util/ValidationSupport.scala index 7de201fd..4be00974 100644 --- a/modules/core/src/test/scala/sangria/util/ValidationSupport.scala +++ b/modules/core/src/test/scala/sangria/util/ValidationSupport.scala @@ -1,5 +1,7 @@ package sangria.util +import sangria.ast +import sangria.execution import sangria.parser.QueryParser import sangria.schema._ import sangria.validation._ @@ -10,6 +12,8 @@ import org.scalatest.matchers.should.Matchers import sangria.ast.Document import sangria.util.tag.@@ import sangria.marshalling.FromInput.CoercedScalaResult +import sangria.execution.ExceptionHandler +import sangria.execution.ValueCollector trait ValidationSupport extends Matchers { type TestField = Field[Unit, Unit] @@ -145,6 +149,13 @@ trait ValidationSupport extends Matchers { ) ) + val OneOfInput = InputObjectType( + "OneOfInput", + List( + InputField("catName", OptionInputType(StringType)), + InputField("dogId", OptionInputType(IntType)) + )).withDirective(ast.Directive(OneOfDirective.name)) + val ComplicatedArgs = ObjectType( "ComplicatedArgs", List[TestField]( @@ -239,6 +250,7 @@ trait ValidationSupport extends Matchers { val QueryRoot = ObjectType( "QueryRoot", List[TestField]( + Field("foo", OptionType(StringType), resolve = _ => None), Field( "human", OptionType(Human), @@ -251,7 +263,13 @@ trait ValidationSupport extends Matchers { Field("catOrDog", OptionType(CatOrDog), resolve = _ => None), Field("dogOrHuman", OptionType(DogOrHuman), resolve = _ => None), Field("humanOrAlien", OptionType(HumanOrAlien), resolve = _ => None), - Field("complicatedArgs", OptionType(ComplicatedArgs), resolve = _ => None) + Field("complicatedArgs", OptionType(ComplicatedArgs), resolve = _ => None), + Field( + "oneOfQuery", + OptionType(CatOrDog), + arguments = List(Argument("input", OneOfInput)), + resolve = _ => None + ) ) ) @@ -348,10 +366,14 @@ trait ValidationSupport extends Matchers { s: Schema[_, _], rules: List[ValidationRule], query: String, - expectedErrors: Seq[(String, Seq[Pos])]) = { + expectedErrors: Seq[(String, Seq[Pos])], + vars: (String, String) = "" -> "" + ) = { val Success(doc) = QueryParser.parse(query) - assertViolations(validator(rules).validateQuery(s, doc, None), expectedErrors: _*) + val variables = getVariableValues(s, vars) + + assertViolations(validator(rules).validateQuery(s, doc, variables, None), expectedErrors: _*) } def expectInputInvalid( @@ -362,12 +384,48 @@ trait ValidationSupport extends Matchers { typeName: String) = { val Success(doc) = QueryParser.parseInputDocumentWithVariables(query) - assertViolations(validator(rules).validateInputDocument(s, doc, typeName), expectedErrors: _*) + assertViolations( + validator(rules) + .validateInputDocument(s, doc, typeName, Map.empty[String, execution.VariableValue]), + expectedErrors: _*) + } + + private def getVariableValues(s: Schema[_, _], vars: (String, String)) = { + import spray.json._ + import sangria.marshalling.sprayJson._ + + val valueCollector = new ValueCollector( + s, + (if (vars._2.nonEmpty) vars._2 else "{}").parseJson, + None, + None, + (), + ExceptionHandler.empty, + None, + true) + + valueCollector + .getVariableValues( + QueryParser + .parse(s"query Foo${if (vars._1.nonEmpty) "(" + vars._1 + ")" else ""} {foo}") + .get + .operations(Some("Foo")) + .variables, + None) + .get } - def expectValid(s: Schema[_, _], rules: List[ValidationRule], query: String) = { + def expectValid( + s: Schema[_, _], + rules: List[ValidationRule], + query: String, + vars: (String, String) = "" -> "" + ) = { val Success(doc) = QueryParser.parse(query) - val errors = validator(rules).validateQuery(s, doc, None) + + val variables = getVariableValues(s, vars) + + val errors = validator(rules).validateQuery(s, doc, variables, None) withClue(renderViolations(errors)) { errors should have size 0 @@ -382,15 +440,22 @@ trait ValidationSupport extends Matchers { val Success(doc) = QueryParser.parseInputDocumentWithVariables(query) withClue("Should validate") { - validator(rules).validateInputDocument(s, doc, typeName) should have size 0 + validator(rules).validateInputDocument( + s, + doc, + typeName, + Map.empty[String, execution.VariableValue]) should have size 0 } } def expectPassesRule(rule: ValidationRule, query: String) = expectValid(schema, rule :: Nil, query) - def expectPasses(query: String) = - expectValid(schema, defaultRule.get :: Nil, query) + def expectPasses( + query: String, + vars: (String, String) = "" -> "" + ) = + expectValid(schema, defaultRule.get :: Nil, query, vars) def expectInputPasses(typeName: String, query: String) = expectValidInput(schema, defaultRule.get :: Nil, query, typeName) @@ -405,12 +470,18 @@ trait ValidationSupport extends Matchers { query, expectedErrors.map { case (msg, pos) => msg -> pos.toList }) - def expectFails(query: String, expectedErrors: List[(String, Option[Pos])]) = + def expectFails( + query: String, + expectedErrors: List[(String, Option[Pos])], + vars: (String, String) = "" -> "" + ) = expectInvalid( schema, defaultRule.get :: Nil, query, - expectedErrors.map { case (msg, pos) => msg -> pos.toList }) + expectedErrors.map { case (msg, pos) => msg -> pos.toList }, + vars + ) def expectInputFails(typeName: String, query: String, expectedErrors: List[(String, List[Pos])]) = expectInputInvalid(schema, defaultRule.get :: Nil, query, expectedErrors, typeName) @@ -432,7 +503,11 @@ trait ValidationSupport extends Matchers { violationCheck: Violation => Unit): Unit = { val schema = Schema.buildFromAst(initialSchemaDoc) val Success(docUnderTest) = QueryParser.parse(sdlUnderTest) - val violations = validator(v.toList).validateQuery(schema, docUnderTest, None) + val violations = validator(v.toList).validateQuery( + schema, + docUnderTest, + Map.empty[String, execution.VariableValue], + None) violations shouldNot be(empty) violations.size shouldBe 1 violationCheck(violations.head) @@ -451,7 +526,11 @@ trait ValidationSupport extends Matchers { v: Option[ValidationRule]): Unit = { val schema = Schema.buildFromAst(initialSchemaDoc) val Success(docUnderTest) = QueryParser.parse(sdlUnderTest) - val violations = validator(v.toList).validateQuery(schema, docUnderTest, None) + val violations = validator(v.toList).validateQuery( + schema, + docUnderTest, + Map.empty[String, execution.VariableValue], + None) violations shouldBe empty } diff --git a/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala b/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala index d64e81ef..7e48b59b 100644 --- a/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala +++ b/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala @@ -42,7 +42,7 @@ class QueryValidatorSpec extends AnyWordSpec { "not limit number of errors returned if the limit is not provided" in { val Success(doc) = QueryParser.parse(invalidQuery) - val result = validator.validateQuery(schema, doc, None) + val result = validator.validateQuery(schema, doc, Map.empty, None) // 10 errors are expected because there are 5 input objects in the list with 2 missing fields each assertResult(10)(result.length) @@ -51,7 +51,7 @@ class QueryValidatorSpec extends AnyWordSpec { val errorsLimit = 5 val Success(doc) = QueryParser.parse(invalidQuery) - val result = validator.validateQuery(schema, doc, Some(errorsLimit)) + val result = validator.validateQuery(schema, doc, Map.empty, Some(errorsLimit)) assertResult(errorsLimit)(result.length) } diff --git a/modules/core/src/test/scala/sangria/validation/rules/ExactlyOneOfFieldGivenSpec.scala b/modules/core/src/test/scala/sangria/validation/rules/ExactlyOneOfFieldGivenSpec.scala new file mode 100644 index 00000000..64549a04 --- /dev/null +++ b/modules/core/src/test/scala/sangria/validation/rules/ExactlyOneOfFieldGivenSpec.scala @@ -0,0 +1,253 @@ +package sangria.validation.rules + +import sangria.util.{Pos, ValidationSupport} +import org.scalatest.wordspec.AnyWordSpec + +class ExactlyOneOfFieldGivenSpec extends AnyWordSpec with ValidationSupport { + + override val defaultRule = Some(new ExactlyOneOfFieldGiven) + + "Validate: exactly oneOf field given" should { + "pass with exactly one non-null field given" in expectPasses(""" + query OneOfQuery { + oneOfQuery(input: { + catName: "Gretel" + }) { + ... on Cat { + name + } + } + } + """) + + "fail with exactly one null field given" in expectFails( + """ + query OneOfQuery { + oneOfQuery(input: { + catName: null + }) { + ... on Cat { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 31))) + ) + + "fail with no fields given" in expectFails( + """ + query OneOfQuery { + oneOfQuery(input: {}) { + ... on Cat { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 31))) + ) + + "fail with more than one non-null args given" in expectFails( + """ + query OneOfQuery { + oneOfQuery(input: { + catName: "Gretel", + dogId: 123 + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 31))) + ) + + "fail with one non-null arg and one null arg given" in expectFails( + """ + query OneOfQuery { + oneOfQuery(input: { + catName: "Gretel", + dogId: null + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 31))) + ) + + "fail with more than one null args given" in expectFails( + """ + query OneOfQuery { + oneOfQuery(input: { + catName: null, + dogId: null + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 31))) + ) + + "fail with an null arg and non-null variable given" in expectFails( + """ + query OneOfQuery($catName: String) { + oneOfQuery(input: { + catName: $catName, + dogId: null + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + "$catName: String" -> """{"catName": "Gretel"}""" + ) + + "fail with an non-null arg and null variable given" in expectFails( + """ + query OneOfQuery($catName: String) { + oneOfQuery(input: { + catName: $catName, + dogId: 123 + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + "$catName: String" -> """{"catName": null}""" + ) + + "fail with a non-null variable and non-null arg given" in expectFails( + """ + query OneOfQuery($catName: String) { + oneOfQuery(input: { + catName: $catName, + dogId: 123 + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + "$catName: String" -> """{"catName": "Gretel"}""" + ) + + "pass with a variable object with only one non-null value" in expectFails( + """ + query OneOfQuery($input: OneOfInput!) { + oneOfQuery(input: $input) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + "$input: OneOfInput!" -> """{"input":{"catName": "Gretel", "dogId": null}}""" + ) + + "fail with a variable object with only null values" in expectFails( + """ + query OneOfQuery($input: OneOfInput!) { + oneOfQuery(input: $input) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + "$input: OneOfInput!" -> """{"input":{"catName": null}}""" + ) + + "fail with a variable object with more than one non-null values" in expectFails( + """ + query OneOfQuery($input: OneOfInput!) { + oneOfQuery(input: $input) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + "$input: OneOfInput!" -> """{"input":{"catName": "Gretel", "dogId": 123}}""" + ) + + "pass with a variable object with exactly one non-null values" in expectPasses( + """ + query OneOfQuery($input: OneOfInput!) { + oneOfQuery(input: $input) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + "$input: OneOfInput!" -> """{"input":{"dogId": 123}}""" + ) + + "passes with a variable that has a default value" in expectPasses( + """ + query OneOfQuery($catName: String = "Gretel") { + oneOfQuery(input: { + catName: $catName + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + """$catName: String = "Gretel"""" -> """{}""" + ) + } +}