diff --git a/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/ElasticsearchIntegrationTestExtension.java b/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/ElasticsearchIntegrationTestExtension.java index 267e03f5c..6ba870b7f 100644 --- a/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/ElasticsearchIntegrationTestExtension.java +++ b/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/ElasticsearchIntegrationTestExtension.java @@ -1,10 +1,26 @@ package com.linkedin.metadata.testing; +import com.linkedin.metadata.testing.annotations.SearchIndexMappings; +import com.linkedin.metadata.testing.annotations.SearchIndexSettings; +import com.linkedin.metadata.testing.annotations.SearchIndexType; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.net.URL; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import org.apache.commons.io.IOUtils; import org.junit.jupiter.api.extension.AfterAllCallback; import org.junit.jupiter.api.extension.AfterEachCallback; import org.junit.jupiter.api.extension.BeforeAllCallback; import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionConfigurationException; import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.platform.commons.util.ClassFilter; +import org.junit.platform.commons.util.ReflectionUtils; /** @@ -14,24 +30,165 @@ */ final class ElasticsearchIntegrationTestExtension implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback { + private static final ExtensionContext.Namespace NAMESPACE = + ExtensionContext.Namespace.create(ElasticsearchIntegrationTestExtension.class); + + private final static String CONTAINER_FACTORY = "containerFactory"; + private final static String CONNECTION = "connection"; + private final static String STATIC_INDICIES = "staticIndicies"; + private final static String INDICIES = "indicies"; @Override - public void afterAll(ExtensionContext context) throws Exception { - // TODO + public void beforeAll(ExtensionContext context) throws Exception { + final Class testClass = context.getTestClass() + .orElseThrow(() -> new ExtensionConfigurationException( + "ElasticSearchIntegrationTestExtension is only supported for classes.")); + + final ExtensionContext.Store store = context.getStore(NAMESPACE); + + final ElasticsearchContainerFactory factory = getContainerFactory(); + final ElasticsearchConnection connection = factory.start(); + + store.put(CONTAINER_FACTORY, factory); + store.put(CONNECTION, connection); + + final List fields = ReflectionUtils.findFields(testClass, field -> { + return ReflectionUtils.isStatic(field) && ReflectionUtils.isPublic(field) && ReflectionUtils.isNotFinal(field) + && SearchIndex.class.isAssignableFrom(field.getType()); + }, ReflectionUtils.HierarchyTraversalMode.TOP_DOWN); + + final SearchIndexFactory indexFactory = new SearchIndexFactory(connection); + final List> indices = createIndices(indexFactory, context.getRequiredTestClass(), fields, + fieldName -> String.format("%s_%s_%s", fieldName, testClass.getSimpleName(), System.currentTimeMillis())); + store.put(STATIC_INDICIES, indices); + } + + private List> createIndices(@Nonnull SearchIndexFactory indexFactory, @Nonnull Object testInstance, + @Nonnull List fields, @Nonnull Function nameFn) throws Exception { + final List> indices = new ArrayList<>(); + + for (Field field : fields) { + final SearchIndexType searchIndexType = field.getAnnotation(SearchIndexType.class); + + if (searchIndexType == null) { + throw new IllegalStateException( + String.format("Field `%s` must be annotated with `SearchIndexType`.", field.getName())); + } + + final String indexName = nameFn.apply(field.getName()).replaceAll("^_*", "").toLowerCase(); + + final SearchIndexSettings settings = field.getAnnotation(SearchIndexSettings.class); + final String settingsJson = settings == null ? null : loadResource(testInstance.getClass(), settings.value()); + + final SearchIndexMappings mappings = field.getAnnotation(SearchIndexMappings.class); + final String mappingsJson = mappings == null ? null : loadResource(testInstance.getClass(), mappings.value()); + + final SearchIndex index = + indexFactory.createIndex(searchIndexType.value(), indexName, settingsJson, mappingsJson); + field.set(testInstance, index); + indices.add(index); + } + + return indices; + } + + private String loadResource(@Nonnull Class testClass, @Nonnull String resource) throws IOException { + final URL resourceUrl = testClass.getResource(resource); + if (resourceUrl == null) { + throw new IllegalArgumentException(String.format("Resource `%s` not found.", resource)); + } + return IOUtils.toString(resourceUrl); + } + + @Nonnull + private Class findContainerFactoryClass() { + final List> classes = ReflectionUtils.findAllClassesInPackage("com.linkedin.metadata.testing", + ClassFilter.of(clazz -> clazz.isAnnotationPresent(ElasticsearchContainerFactory.Implementation.class))); + + if (classes.size() == 0) { + throw new IllegalStateException("Could not find any ElasticsearchContainerFactory implementations."); + } + + if (classes.size() > 1) { + throw new IllegalStateException( + String.format("Found %s ElasticsearchContainerFactory implementations, expected 1. Found %s.", classes.size(), + String.join(", ", classes.stream().map(Class::getSimpleName).collect(Collectors.toList())))); + } + + return classes.get(0); + } + + @Nonnull + private ElasticsearchContainerFactory getContainerFactory() throws Exception { + final Class clazz = findContainerFactoryClass(); + + if (!ElasticsearchContainerFactory.class.isAssignableFrom(clazz)) { + throw new IllegalStateException(String.format( + "Provided class `%s` to ElasticsearchIntegrationTest, but did not inherit from " + + "ElasticsearchContainerFactory.", clazz.toString())); + } + + Constructor constructor; + try { + constructor = clazz.getConstructor(); + } catch (NoSuchMethodException e) { + throw new NoSuchMethodException(String.format( + "Expected ElasticsearchContainerFactory, `%s`, to have a default, public, constructor but found none.", + clazz.toString())); + } + + return (ElasticsearchContainerFactory) constructor.newInstance(); } @Override - public void afterEach(ExtensionContext context) throws Exception { - // TODO + public void beforeEach(ExtensionContext context) throws Exception { + final ExtensionContext.Store store = context.getStore(NAMESPACE); + final ElasticsearchConnection connection = store.get(CONNECTION, ElasticsearchConnection.class); + + final List fields = ReflectionUtils.findFields(context.getRequiredTestClass(), field -> { + return ReflectionUtils.isNotStatic(field) && ReflectionUtils.isPublic(field) && ReflectionUtils.isNotFinal(field) + && SearchIndex.class.isAssignableFrom(field.getType()); + }, ReflectionUtils.HierarchyTraversalMode.TOP_DOWN); + + final SearchIndexFactory indexFactory = new SearchIndexFactory(connection); + final List> indices = createIndices(indexFactory, context.getRequiredTestInstance(), fields, + fieldName -> String.format("%s_%s_%s_%s", fieldName, context.getRequiredTestMethod().getName(), + context.getRequiredTestClass().getSimpleName(), System.currentTimeMillis())); + store.put(INDICIES, indices); } + @SuppressWarnings("unchecked") @Override - public void beforeAll(ExtensionContext context) throws Exception { - // TODO + public void afterAll(ExtensionContext context) throws Exception { + final ExtensionContext.Store store = context.getStore(NAMESPACE); + + final List> indices = (List>) store.get(STATIC_INDICIES, List.class); + final ElasticsearchConnection connection = store.get(CONNECTION, ElasticsearchConnection.class); + + cleanUp(connection, indices); + + // don't need to close the factory since it implements CloseableResource, junit will close it since it is in the + // store } + @SuppressWarnings("unchecked") @Override - public void beforeEach(ExtensionContext context) throws Exception { - // TODO + public void afterEach(ExtensionContext context) throws Exception { + final ExtensionContext.Store store = context.getStore(NAMESPACE); + + final List> indices = (List>) store.get(INDICIES, List.class); + final ElasticsearchConnection connection = store.get(CONNECTION, ElasticsearchConnection.class); + + if (indices != null) { + cleanUp(connection, indices); + } + } + + private void cleanUp(@Nonnull ElasticsearchConnection connection, @Nonnull List> indices) { + for (SearchIndex i : indices) { + connection.getTransportClient().admin().indices().prepareDelete(i.getName()).get(); + } + + indices.clear(); } } diff --git a/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/SearchIndexFactory.java b/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/SearchIndexFactory.java new file mode 100644 index 000000000..be4f04cbf --- /dev/null +++ b/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/SearchIndexFactory.java @@ -0,0 +1,50 @@ +package com.linkedin.metadata.testing; + +import com.linkedin.data.template.RecordTemplate; +import java.util.concurrent.ExecutionException; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; +import org.elasticsearch.common.xcontent.XContentType; + + +/** + * Factory to create {@link SearchIndex} instances for testing. + */ +final class SearchIndexFactory { + private final ElasticsearchConnection _connection; + + SearchIndexFactory(@Nonnull ElasticsearchConnection connection) { + _connection = connection; + } + + /** + * Creates a search index to read / write the given document type for testing. + * + *

This will create an index on the Elasticsearch instance with a unique name. + * + * @param documentClass the document type + * @param name the name to use for the index + */ + public SearchIndex createIndex(@Nonnull Class documentClass, + @Nonnull String name, @Nullable String settingsJson, @Nullable String mappingsJson) { + final CreateIndexRequest createIndexRequest = new CreateIndexRequest(name); + + if (settingsJson != null) { + createIndexRequest.settings(settingsJson, XContentType.JSON); + } + + if (mappingsJson != null) { + // TODO + createIndexRequest.mapping("doc", mappingsJson, XContentType.JSON); + } + + try { + _connection.getTransportClient().admin().indices().create(createIndexRequest).get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + + return new SearchIndex<>(documentClass, _connection, name); + } +} diff --git a/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/annotations/SearchIndexType.java b/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/annotations/SearchIndexType.java index 19aedd532..3138dc196 100644 --- a/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/annotations/SearchIndexType.java +++ b/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/annotations/SearchIndexType.java @@ -1,6 +1,7 @@ package com.linkedin.metadata.testing.annotations; import com.linkedin.data.template.RecordTemplate; +import com.linkedin.metadata.testing.SearchIndex; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; @@ -9,9 +10,9 @@ /** - * Annotates the given {@link com.linkedin.metadata.testing.SearchIndex} field with the document type. + * Annotates the given {@link SearchIndex} field with the document type. * - *

Required annotation for {@link com.linkedin.metadata.testing.SearchIndex} instances in tests.

+ *

Required annotation for {@link SearchIndex} instances in tests.

*/ @Target(ElementType.FIELD) @Retention(RetentionPolicy.RUNTIME) @@ -19,7 +20,7 @@ /** * The search document class for this index. * - *

Used to create an instance of the {@link com.linkedin.metadata.testing.SearchIndex} during testing. + *

Used to create an instance of the {@link SearchIndex} during testing. */ @Nonnull Class value();