diff --git a/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/ElasticsearchIntegrationTestExtension.java b/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/ElasticsearchIntegrationTestExtension.java index 6ba870b7f..a4c5e1fc8 100644 --- a/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/ElasticsearchIntegrationTestExtension.java +++ b/testing/elasticsearch-dao-integ-testing/src/main/java/com/linkedin/metadata/testing/ElasticsearchIntegrationTestExtension.java @@ -33,10 +33,26 @@ final class ElasticsearchIntegrationTestExtension private static final ExtensionContext.Namespace NAMESPACE = ExtensionContext.Namespace.create(ElasticsearchIntegrationTestExtension.class); - private final static String CONTAINER_FACTORY = "containerFactory"; - private final static String CONNECTION = "connection"; private final static String STATIC_INDICIES = "staticIndicies"; private final static String INDICIES = "indicies"; + private static ElasticsearchConnection _connection = null; + + private void startElasticSearch() throws Exception { + if (_connection != null) { + return; + } + + final ElasticsearchContainerFactory factory = getContainerFactory(); + _connection = factory.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + factory.close(); + } catch (Throwable throwable) { + throwable.printStackTrace(); + } + })); + } @Override public void beforeAll(ExtensionContext context) throws Exception { @@ -46,18 +62,14 @@ public void beforeAll(ExtensionContext context) throws Exception { final ExtensionContext.Store store = context.getStore(NAMESPACE); - final ElasticsearchContainerFactory factory = getContainerFactory(); - final ElasticsearchConnection connection = factory.start(); - - store.put(CONTAINER_FACTORY, factory); - store.put(CONNECTION, connection); + startElasticSearch(); final List fields = ReflectionUtils.findFields(testClass, field -> { return ReflectionUtils.isStatic(field) && ReflectionUtils.isPublic(field) && ReflectionUtils.isNotFinal(field) && SearchIndex.class.isAssignableFrom(field.getType()); }, ReflectionUtils.HierarchyTraversalMode.TOP_DOWN); - final SearchIndexFactory indexFactory = new SearchIndexFactory(connection); + final SearchIndexFactory indexFactory = new SearchIndexFactory(_connection); final List> indices = createIndices(indexFactory, context.getRequiredTestClass(), fields, fieldName -> String.format("%s_%s_%s", fieldName, testClass.getSimpleName(), System.currentTimeMillis())); store.put(STATIC_INDICIES, indices); @@ -143,14 +155,13 @@ private ElasticsearchContainerFactory getContainerFactory() throws Exception { @Override public void beforeEach(ExtensionContext context) throws Exception { final ExtensionContext.Store store = context.getStore(NAMESPACE); - final ElasticsearchConnection connection = store.get(CONNECTION, ElasticsearchConnection.class); final List fields = ReflectionUtils.findFields(context.getRequiredTestClass(), field -> { return ReflectionUtils.isNotStatic(field) && ReflectionUtils.isPublic(field) && ReflectionUtils.isNotFinal(field) && SearchIndex.class.isAssignableFrom(field.getType()); }, ReflectionUtils.HierarchyTraversalMode.TOP_DOWN); - final SearchIndexFactory indexFactory = new SearchIndexFactory(connection); + final SearchIndexFactory indexFactory = new SearchIndexFactory(_connection); final List> indices = createIndices(indexFactory, context.getRequiredTestInstance(), fields, fieldName -> String.format("%s_%s_%s_%s", fieldName, context.getRequiredTestMethod().getName(), context.getRequiredTestClass().getSimpleName(), System.currentTimeMillis())); @@ -163,9 +174,8 @@ public void afterAll(ExtensionContext context) throws Exception { final ExtensionContext.Store store = context.getStore(NAMESPACE); final List> indices = (List>) store.get(STATIC_INDICIES, List.class); - final ElasticsearchConnection connection = store.get(CONNECTION, ElasticsearchConnection.class); - cleanUp(connection, indices); + cleanUp(indices); // don't need to close the factory since it implements CloseableResource, junit will close it since it is in the // store @@ -177,16 +187,15 @@ public void afterEach(ExtensionContext context) throws Exception { final ExtensionContext.Store store = context.getStore(NAMESPACE); final List> indices = (List>) store.get(INDICIES, List.class); - final ElasticsearchConnection connection = store.get(CONNECTION, ElasticsearchConnection.class); if (indices != null) { - cleanUp(connection, indices); + cleanUp(indices); } } - private void cleanUp(@Nonnull ElasticsearchConnection connection, @Nonnull List> indices) { + private void cleanUp(@Nonnull List> indices) { for (SearchIndex i : indices) { - connection.getTransportClient().admin().indices().prepareDelete(i.getName()).get(); + _connection.getTransportClient().admin().indices().prepareDelete(i.getName()).get(); } indices.clear();