Skip to content

Commit

Permalink
UT for AzureCC provider
Browse files Browse the repository at this point in the history
  • Loading branch information
lunwang-ttd committed Oct 9, 2023
1 parent de57cb6 commit 144b52e
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 20 deletions.
12 changes: 12 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@
<version>2.10</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>5.2.0</version>
<scope>test</scope>
</dependency>
</dependencies>

<distributionManagement>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import com.uid2.enclave.IAttestationProvider;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;

import java.io.IOException;
import java.net.HttpURLConnection;
Expand All @@ -13,27 +12,34 @@
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;

public class AzureCCAttestationProvider implements IAttestationProvider {
private final String maaEndpoint;
private static final String DefaultMaaEndpoint = "sharedeus.eus.attest.azure.net";
public static final String DefaultMaaEndpoint = "sharedeus.eus.attest.azure.net";

private final String skrEndpoint;
private static final String DefaultSkrEndpoint = "http://localhost:8080/attest/maa";
public static final String DefaultSkrEndpoint = "http://localhost:8080/attest/maa";

private final HttpClient httpClient;
private String location;

public AzureCCAttestationProvider() {
this(DefaultSkrEndpoint, DefaultMaaEndpoint, null);
this(DefaultSkrEndpoint, DefaultMaaEndpoint, null, null);
}
public AzureCCAttestationProvider(String maaEndpoint) {
this(maaEndpoint, DefaultSkrEndpoint, null);
this(maaEndpoint, DefaultSkrEndpoint, null, null);
}

public AzureCCAttestationProvider(String maaEndpoint, String skrEndpoint) {
this(maaEndpoint, skrEndpoint, null);
this(maaEndpoint, skrEndpoint, null, null);
}

public AzureCCAttestationProvider(String maaEndpoint, String skrEndpoint, HttpClient httpClient) {
this(maaEndpoint, skrEndpoint, httpClient, null);
}

public AzureCCAttestationProvider(String maaEndpoint, String skrEndpoint, HttpClient httpClient, String location) {
this.maaEndpoint = maaEndpoint;
this.skrEndpoint = skrEndpoint;

Expand All @@ -42,26 +48,29 @@ public AzureCCAttestationProvider(String maaEndpoint, String skrEndpoint, HttpCl
} else {
this.httpClient = HttpClient.newHttpClient();
}

if (location != null) {
this.location = location;
}
}

@Override
public byte[] getAttestationRequest(byte[] publicKey) throws AttestationException {
var base64Encoder = Base64.getEncoder();
var gson = new Gson();

var runtimeData = new HashMap<String, String>();
runtimeData.put("location", getLocation());
runtimeData.put("publicKey", base64Encoder.encodeToString(publicKey));
var runtimeData = Map.of("location", getLocation(), "publicKey", base64Encoder.encodeToString(publicKey));
String runtimeDataJson = gson.toJson(runtimeData);

var body = new HashMap<String, String>();
body.put("maa_endpoint", this.maaEndpoint);
body.put("runtime_data", base64Encoder.encodeToString(runtimeDataJson.getBytes()));
String bodyJson = gson.toJson(body);
var skrRequest = new SkrRequest();
skrRequest.maa_endpoint = this.maaEndpoint;
skrRequest.runtime_data = base64Encoder.encodeToString(runtimeDataJson.getBytes());

String requestBody = gson.toJson(skrRequest);
var request = HttpRequest.newBuilder()
.uri(URI.create(skrEndpoint))
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(bodyJson))
.POST(HttpRequest.BodyPublishers.ofString(requestBody))
.build();

try {
Expand All @@ -70,13 +79,15 @@ public byte[] getAttestationRequest(byte[] publicKey) throws AttestationExceptio
throw new AttestationException("Skr failed with status code: " + response.statusCode() + " body: " + response.body());
}

var responseBodyType = new TypeToken<HashMap<String, String>>(){};
var responseBody = gson.fromJson(response.body(), responseBodyType);
var token = responseBody.get("token");
if (token == null) {
var skrResponse = gson.fromJson(response.body(), SkrResponse.class);
if (skrResponse == null) {
throw new AttestationException("response is null");
}

if (skrResponse.token == null || skrResponse.token.isEmpty()) {
throw new AttestationException("token field not exist in Skr response");
}
return token.getBytes();
return skrResponse.token.getBytes();
} catch (IOException e) {
throw new AttestationException(e);
} catch (InterruptedException e) {
Expand All @@ -85,6 +96,20 @@ public byte[] getAttestationRequest(byte[] publicKey) throws AttestationExceptio
}

private String getLocation() throws AttestationException {
if (this.location != null) {
return this.location;
}

// TODO(lun.wang) get location from meta server
return "";
}

private static class SkrRequest {
private String maa_endpoint;
private String runtime_data;
}

private static class SkrResponse {
private String token;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package com.uid2.attestation.azure;

import com.uid2.enclave.AttestationException;

import com.google.gson.Gson;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;

import java.net.HttpURLConnection;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.Map;

public class AzureCCAttestationProviderTest {
@Test
public void testGetAttestationRequestSuccess() throws Exception {
var gson = new Gson();

// Mock response
final var publicTokenMock = new byte[] {0x01, 0x02};
final var maaTokenMock = "abc";
final var httpResponseMock = mock(HttpResponse.class);
when(httpResponseMock.statusCode()).thenReturn(HttpURLConnection.HTTP_OK);
when(httpResponseMock.body()).thenReturn(gson.toJson(Map.of("token", maaTokenMock)));

final var httpClientMock = mock(HttpClient.class);
when(httpClientMock.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(httpResponseMock);

// Verify output
final var provider = new AzureCCAttestationProvider(AzureCCAttestationProvider.DefaultMaaEndpoint,
AzureCCAttestationProvider.DefaultSkrEndpoint, httpClientMock);
var output = provider.getAttestationRequest(publicTokenMock);
Assert.assertArrayEquals(maaTokenMock.getBytes(), output);

// Verify sent request
var requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(httpClientMock).send(requestCaptor.capture(), any(HttpResponse.BodyHandler.class));
var request = requestCaptor.getValue();
Assert.assertEquals(AzureCCAttestationProvider.DefaultSkrEndpoint, request.uri().toString());
}

@Test
public void testGetAttestationRequestFailure_InvalidStatusCode() throws Exception {
final var publicTokenMock = new byte[] {0x01, 0x02};
final var httpResponseMock = mock(HttpResponse.class);
when(httpResponseMock.statusCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR);

final var httpClientMock = mock(HttpClient.class);
when(httpClientMock.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(httpResponseMock);

final var provider = new AzureCCAttestationProvider(AzureCCAttestationProvider.DefaultMaaEndpoint,
AzureCCAttestationProvider.DefaultSkrEndpoint, httpClientMock);
var thrown = Assert.assertThrows(AttestationException.class, () -> provider.getAttestationRequest(publicTokenMock));
Assert.assertTrue(thrown.getMessage().startsWith("Skr failed with status code: " + HttpURLConnection.HTTP_INTERNAL_ERROR));
}

@Test
public void testGetAttestationRequestFailure_EmptyResponseBody() throws Exception {
final var publicTokenMock = new byte[] {0x01, 0x02};
final var httpResponseMock = mock(HttpResponse.class);
when(httpResponseMock.statusCode()).thenReturn(HttpURLConnection.HTTP_OK);

final var httpClientMock = mock(HttpClient.class);
when(httpClientMock.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(httpResponseMock);

final var provider = new AzureCCAttestationProvider(AzureCCAttestationProvider.DefaultMaaEndpoint,
AzureCCAttestationProvider.DefaultSkrEndpoint, httpClientMock);
var thrown = Assert.assertThrows(AttestationException.class, () -> provider.getAttestationRequest(publicTokenMock));
Assert.assertEquals("response is null", thrown.getMessage());
}

@Test
public void testGetAttestationRequestFailure_InvalidResponseBody() throws Exception {
var gson = new Gson();
final var publicTokenMock = new byte[] {0x01, 0x02};
final var httpResponseMock = mock(HttpResponse.class);
when(httpResponseMock.statusCode()).thenReturn(HttpURLConnection.HTTP_OK);
when(httpResponseMock.body()).thenReturn(gson.toJson(Map.of("key", 123)));

final var httpClientMock = mock(HttpClient.class);
when(httpClientMock.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(httpResponseMock);

final var provider = new AzureCCAttestationProvider(AzureCCAttestationProvider.DefaultMaaEndpoint,
AzureCCAttestationProvider.DefaultSkrEndpoint, httpClientMock);
var thrown = Assert.assertThrows(AttestationException.class, () -> provider.getAttestationRequest(publicTokenMock));
Assert.assertEquals("token field not exist in Skr response", thrown.getMessage());
}
}

0 comments on commit 144b52e

Please sign in to comment.