Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSockets OnMessage Parameter Validation Diagnostic #277

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class WebSocketConstants {
public static final String PATHPARAM_DIAGNOSTIC_CODE = "ChangePathParamValue";

public static final String ANNOTATION_VALUE = "value";
public static final String ANNOTATION_DECODER = "decoders";

public static final String URI_SEPARATOR = "/";
public static final String CURLY_BRACE_START = "{";
Expand All @@ -47,13 +48,14 @@ public class WebSocketConstants {
/* Diagnostic codes */
public static final String DIAGNOSTIC_CODE_ON_OPEN_INVALID_PARAMS = "OnOpenChangeInvalidParam";
public static final String DIAGNOSTIC_CODE_ON_CLOSE_INVALID_PARAMS = "OnCloseChangeInvalidParam";
public static final String DIAGNOSTIC_CODE_ON_MESSAGE_INVALID_PARAMS = "OnMessageInvalidMessageParams";

public static final String DIAGNOSTIC_SERVER_ENDPOINT_NO_SLASH = "Server endpoint paths must start with a leading '/'.";
public static final String DIAGNOSTIC_SERVER_ENDPOINT_NOT_LEVEL1 = "Server endpoint paths must be a URI-template (level-1) or a partial URI.";
public static final String DIAGNOSTIC_SERVER_ENDPOINT_RELATIVE = "Server endpoint paths must not contain the sequences '/../', '/./' or '//'.";
public static final String DIAGNOSTIC_SERVER_ENDPOINT_DUPLICATE_VAR = "Server endpoint paths must not use the same variable more than once in a path.";
public static final String DIAGNOSTIC_SERVER_ENDPOINT= "ChangeInvalidServerEndpoint";

/* https://jakarta.ee/specifications/websocket/2.0/websocket-spec-2.0.html#applications */
// Class Level Annotations
public static final String SERVER_ENDPOINT_ANNOTATION = "ServerEndpoint";
Expand All @@ -68,28 +70,74 @@ public class WebSocketConstants {
/* Annotations */
public static final String ON_OPEN = "OnOpen";
public static final String ON_CLOSE = "OnClose";
public static final String ON_MESSAGE = "OnMessage";

public static final String IS_ANNOTATION = "isAnnotation";

/* Types */
public static final String PATH_PARAM_ANNOTATION = "PathParam";

// For OnOpen annotation
/* For OnOpen annotation */
public static final Set<String> ON_OPEN_PARAM_OPT_TYPES= new HashSet<>(Arrays.asList("jakarta.websocket.EndpointConfig", "jakarta.websocket.Session"));
public static final Set<String> RAW_ON_OPEN_PARAM_OPT_TYPES= new HashSet<>(Arrays.asList("EndpointConfig", "Session"));

/* For OnClose annotation */
public static final Set<String> ON_CLOSE_PARAM_OPT_TYPES = new HashSet<>(Arrays.asList("jakarta.websocket.CloseReason", "jakarta.websocket.Session"));
public static final Set<String> RAW_ON_CLOSE_PARAM_OPT_TYPES = new HashSet<>(Arrays.asList("CloseReason", "Session"));
/* For OnMessage annotation */
public static final Set<String> ON_MESSAGE_PARAM_OPT_TYPES = new HashSet<>(Arrays.asList("jakarta.websocket.Session"));
public static final Set<String> RAW_ON_MESSAGE_PARAM_OPT_TYPES = new HashSet<>(Arrays.asList("Session"));
/* For OnMessage (Text) annotation */
public static final Set<String> ON_MESSAGE_TEXT_TYPES = new HashSet<>(Arrays.asList("java.lang.String", "java.io.Reader", "String", "Reader"));
/* For OnMessage (Text) annotation */
public static final Set<String> ON_MESSAGE_BINARY_TYPES = new HashSet<>(Arrays.asList("java.nio.ByteBuffer", "java.io.InputStream", "ByteBuffer", "InputStream"));
/* For OnMessage (Text) annotation */
public static final Set<String> ON_MESSAGE_PONG_TYPES = new HashSet<>(Arrays.asList("jakarta.websocket.PongMessage", "PongMessage"));
Comment on lines +86 to +94
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you pull the latest from main, I suspect that you'll be able to use the already defined strings for these types.


/* Wrapper Objects */
public static final Set<String> RAW_WRAPPER_OBJS = new HashSet<>(Arrays.asList("String", "Boolean", "Integer", "Long", "Double", "Float"));
public static final Set<String> WRAPPER_OBJS = RAW_WRAPPER_OBJS.stream().map(raw -> "java.lang.".concat(raw)).collect(Collectors.toSet());

public static final String RAW_STRING_TYPE = "String";
public static final String STRING_OBJ = "java.lang.String";

public static final String RAW_BOOLEAN_TYPE = "boolean";
public static final String BOOLEAN_OBJ = "java.lang.Boolean";

public static final String RAW_BYTEBUFFER_OBJ = "ByteBuffer";
public static final String BYTEBUFFER_OBJ = "java.nio.ByteBuffer";

// Messages
public static final String PARAM_TYPE_DIAG_MSG = "Invalid parameter type. When using %s, parameter must be of type: \n- %s\n- annotated with @PathParams and of type String or any Java primitive type or boxed version thereof";

public static final String ONMESSAGE_DUPLICATE_SPECIAL_MSG = "Only one optional parameter of this type is allowed for a method with the @OnMessage annotation.";
public static final String ONMESSAGE_INVALID_PATH_PARAM_MSG = "Only String and Java primitive types are allowed to be annotated with the @PathParam annotation.";
public static final String ONMESSAGE_DUPLICATE_MESSAGE_PARAM_MSG = "Multiple parameters of this type are not allowed for a method with the @OnMessage annotation.";


public static final String TEXT_PARAMS_DIAG_MSG = "Invalid parameter type. OnMessage methods for handling text messages may have the following parameters: \r\n"
+ " - String to receive the whole message\r\n"
+ " - Java primitive or class equivalent to receive the whole message converted to that type\r\n"
+ " - String and boolean pair to receive the message in parts\r\n"
+ " - Reader to receive the whole message as a blocking stream\r\n"
+ " - any object parameter for which the endpoint has a text decoder (Decoder.Text or Decoder.TextStream)";

public static final String BINARY_PARAMS_DIAG_MSG = "Invalid parameter type. OnMessage methods for handling binary messages may have the following parameters: \r\n"
+ " - byte[] or ByteBuffer to receive the whole message\r\n"
+ " - byte[] and boolean pair, or ByteBuffer and boolean pair to receive the message in parts\r\n"
+ " - InputStream to receive the whole message as a blocking stream\r\n"
+ " - any object parameter for which the endpoint has a binary decoder (Decoder.Binary or Decoder.BinaryStream)";

public static final String PONG_PARAMS_DIAG_MSG = "Invalid parameter type. OnMessage methods for handling pong messages may have the following parameters: \r\n"
+ " - PongMessage for handling pong messages";

public static final String INVALID_PARAMS_DIAG_MSG = "Invalid parameter type. Please see @OnMessage API Specification for valid parameter specifications.";

// Enums
public enum MESSAGE_FORMAT {TEXT, BINARY, PONG, INVALID};

/* Regex */
// Check for any URI strings that contain //, /./, or /../
public static final String REGEX_RELATIVE_PATHS = ".*\\/\\.{0,2}\\/.*";
// Check that a URI string is a valid level 1 variable (wrapped in curly brackets): alpha-numeric characters, dash, or a percent encoded character
public static final String REGEX_URI_VARIABLE = "\\{(\\w|-|%20|%21|%23|%24|%25|%26|%27|%28|%29|%2A|%2B|%2C|%2F|%3A|%3B|%3D|%3F|%40|%5B|%5D)+\\}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.ArrayList;

Expand Down Expand Up @@ -83,7 +84,8 @@ public void collectDiagnostics(ICompilationUnit unit, List<Diagnostic> diagnosti
invalidParamsCheck(type, WebSocketConstants.ON_CLOSE, WebSocketConstants.ON_CLOSE_PARAM_OPT_TYPES,
WebSocketConstants.RAW_ON_CLOSE_PARAM_OPT_TYPES,
WebSocketConstants.DIAGNOSTIC_CODE_ON_CLOSE_INVALID_PARAMS, unit, diagnostics);

onMessageParamsCheck(type, unit, diagnostics);

// PathParam URI Mismatch Warning Diagnostic
uriMismatchWarningCheck(type, diagnostics, unit);
// ServerEndpoint annotation diagnostics
Expand Down Expand Up @@ -153,6 +155,164 @@ private void invalidParamsCheck(IType type, String methodAnnotTarget, Set<String
}
}
}

private void onMessageParamsCheck(IType type, ICompilationUnit unit, List<Diagnostic> diagnostics) throws JavaModelException {

boolean endpointDecodersSpecified = checkEndpointHasDecoder(type);

IMethod[] allMethods = type.getMethods();
for (IMethod method : allMethods) {

IAnnotation[] allAnnotations = method.getAnnotations();
for (IAnnotation annotation : allAnnotations) {
if (annotation.getElementName().equals(WebSocketConstants.ON_MESSAGE)) {

Set<String> seenSpecialParams = new HashSet<>();
Map<String, ILocalVariable> seenMessageParams = new HashMap<>();

ILocalVariable[] allParams = method.getParameters();
for (ILocalVariable param : allParams) {

String signature = param.getTypeSignature();
String formatSignature = signature.replace("/", ".");
String resolvedTypeName = JavaModelUtil.getResolvedTypeName(formatSignature, type);
String finalTypeName;
boolean isSpecialType;

IAnnotation[] param_annotations = param.getAnnotations();
boolean hasPathParamAnnot = Arrays.asList(param_annotations).stream().anyMatch(
annot -> annot.getElementName().equals(WebSocketConstants.PATH_PARAM_ANNOTATION));

if (resolvedTypeName != null) {
isSpecialType = WebSocketConstants.ON_MESSAGE_PARAM_OPT_TYPES.contains(resolvedTypeName);
finalTypeName = resolvedTypeName;
} else {
String simpleParamType = Signature.getSignatureSimpleName(signature);
isSpecialType = WebSocketConstants.RAW_ON_MESSAGE_PARAM_OPT_TYPES.contains(simpleParamType);
finalTypeName = simpleParamType;
}

if (isSpecialType) {
if (seenSpecialParams.contains(finalTypeName)){
Diagnostic duplicateSpecialParamDiagnostic = createDiagnostic(param, unit,
WebSocketConstants.ONMESSAGE_DUPLICATE_SPECIAL_MSG,
WebSocketConstants.DIAGNOSTIC_CODE_ON_MESSAGE_INVALID_PARAMS);
diagnostics.add(duplicateSpecialParamDiagnostic);
} else {
seenSpecialParams.add(finalTypeName);
}
} else if (hasPathParamAnnot) {
boolean isPrimitive = JavaModelUtil.isPrimitive(formatSignature);
boolean isPrimWrapped = isWrapper(finalTypeName);
if (!isPrimitive && !isPrimWrapped) {
Diagnostic invalidPathParamDiagnostic = createDiagnostic(param, unit,
WebSocketConstants.ONMESSAGE_INVALID_PATH_PARAM_MSG,
WebSocketConstants.DIAGNOSTIC_CODE_ON_MESSAGE_INVALID_PARAMS);
diagnostics.add(invalidPathParamDiagnostic);
}
} else {
if (seenMessageParams.containsKey(finalTypeName)) {
Diagnostic duplicateMessageParamDiagnostic = createDiagnostic(param, unit,
WebSocketConstants.ONMESSAGE_DUPLICATE_MESSAGE_PARAM_MSG,
WebSocketConstants.DIAGNOSTIC_CODE_ON_MESSAGE_INVALID_PARAMS);
diagnostics.add(duplicateMessageParamDiagnostic);
} else {
seenMessageParams.put(finalTypeName, param);
}
}
}

WebSocketConstants.MESSAGE_FORMAT methodType = null;

Set<String> intersection = new HashSet<>(seenMessageParams.keySet());
Set<String> difference = new HashSet<>(seenMessageParams.keySet());

intersection.retainAll(WebSocketConstants.ON_MESSAGE_TEXT_TYPES);
if (intersection.size() > 0) {
// TEXT Message
methodType = WebSocketConstants.MESSAGE_FORMAT.TEXT;
difference.removeAll(WebSocketConstants.ON_MESSAGE_TEXT_TYPES);
} else {
intersection = new HashSet<>(seenMessageParams.keySet());
intersection.retainAll(WebSocketConstants.ON_MESSAGE_BINARY_TYPES);
if (intersection.size() > 0) {
// BINARY Message
methodType = WebSocketConstants.MESSAGE_FORMAT.BINARY;
difference.removeAll(WebSocketConstants.ON_MESSAGE_BINARY_TYPES);
} else {
intersection = new HashSet<>(seenMessageParams.keySet());
intersection.retainAll(WebSocketConstants.ON_MESSAGE_PONG_TYPES);
if (intersection.size() > 0) {
// PONG Message
methodType = WebSocketConstants.MESSAGE_FORMAT.PONG;
difference.removeAll(WebSocketConstants.ON_MESSAGE_PONG_TYPES);
} else {
// Invalid Message
methodType = WebSocketConstants.MESSAGE_FORMAT.INVALID;
}
}
}

switch (methodType) {
case TEXT:
addDiagnosticsForInvalidMessageParams(difference, seenMessageParams, diagnostics, unit,
true, endpointDecodersSpecified, WebSocketConstants.TEXT_PARAMS_DIAG_MSG);
break;
case BINARY:
addDiagnosticsForInvalidMessageParams(difference, seenMessageParams, diagnostics, unit,
true, endpointDecodersSpecified, WebSocketConstants.BINARY_PARAMS_DIAG_MSG);
break;
case PONG:
addDiagnosticsForInvalidMessageParams(difference, seenMessageParams, diagnostics, unit,
false, false, WebSocketConstants.PONG_PARAMS_DIAG_MSG);
break;
case INVALID:
default:
addDiagnosticsForInvalidMessageParams(difference, seenMessageParams, diagnostics, unit,
false, false, WebSocketConstants.INVALID_PARAMS_DIAG_MSG);
}
}
}
}
}

private void addDiagnosticsForInvalidMessageParams(Set<String> diff, Map<String, ILocalVariable> params,
List<Diagnostic> diagnostics, ICompilationUnit unit, boolean boolAllowed, boolean decodersSpecified, String msg) {
if (!decodersSpecified) {
for (String invalidParam : diff) {
if (boolAllowed && (invalidParam.equals(WebSocketConstants.BOOLEAN_OBJ))) {
continue;
}
ILocalVariable param = params.get(invalidParam);
Diagnostic invalidTextParam = createDiagnostic(param, unit,
msg, WebSocketConstants.DIAGNOSTIC_CODE_ON_MESSAGE_INVALID_PARAMS);
diagnostics.add(invalidTextParam);
}
}
}

/**
* Checks if a WebSocket EndPoint annotation contains custom decoders
*
* @param type representing the class
* @return boolean to represent if decoders are present
* @throws JavaModelException
*/
private boolean checkEndpointHasDecoder(IType type) throws JavaModelException {
IAnnotation[] endpointAnnotations = type.getAnnotations();
for (IAnnotation annotation : endpointAnnotations) {
if (annotation.getElementName().equals(WebSocketConstants.SERVER_ENDPOINT_ANNOTATION)
|| annotation.getElementName().equals(WebSocketConstants.CLIENT_ENDPOINT_ANNOTATION)) {
IMemberValuePair[] valuePairs = annotation.getMemberValuePairs();
for (IMemberValuePair pair : valuePairs) {
if (pair.getMemberName().equals(WebSocketConstants.ANNOTATION_DECODER)) {
return true;
}
}
}
}
return false;
}

/**
* Creates a warning diagnostic if a PathParam annotation does not match any
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package io.openliberty.sample.jakarta.websocket;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like you accidentally added whitespace here.

import java.io.IOException;

import jakarta.websocket.CloseReason;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.openliberty.sample.jakarta.websocket;

import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.nio.ByteBuffer;

import jakarta.websocket.server.PathParam;
import jakarta.websocket.server.ServerEndpoint;
import jakarta.websocket.OnOpen;
import jakarta.websocket.PongMessage;
import jakarta.websocket.OnError;
import jakarta.websocket.OnMessage;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.OnClose;
import jakarta.websocket.Session;

@ServerEndpoint(value = "/demo/{test}/var/{abcd}")
public class InvalidParamTypeBinary {
private static Session session;
@OnOpen
public void OnOpen(Session session) throws IOException {
System.out.println("Websocket opened: " + session.getId().toString());
}

@OnMessage
public void OnMessage(ByteBuffer bb, PongMessage invalid) throws IOException {
System.out.println("Websocket opened: " + session.getId().toString());
}

@OnClose
public void OnClose(Session session) throws IOException {
System.out.println("WebSocket closed for " + session.getId());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.openliberty.sample.jakarta.websocket;

import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.nio.ByteBuffer;

import jakarta.websocket.server.PathParam;
import jakarta.websocket.server.ServerEndpoint;
import jakarta.websocket.OnOpen;
import jakarta.websocket.PongMessage;
import jakarta.websocket.OnError;
import jakarta.websocket.OnMessage;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.OnClose;
import jakarta.websocket.Session;

@ServerEndpoint(value = "/demo/{test}/var/{abcd}")
public class InvalidParamTypePong {
private static Session session;

@OnOpen
public void OnOpen(Session session) throws IOException {
System.out.println("Websocket opened: " + session.getId().toString());
}

@OnMessage
public void OnMessage(PongMessage pong, int invalid) throws IOException {
System.out.println("Websocket opened: " + session.getId().toString());
}

@OnClose
public void OnClose(Session session) throws IOException {
System.out.println("WebSocket closed for " + session.getId());
}
}
Loading