From df5e6a5c7109e4397b05fd02fe1568b35adcb1ce Mon Sep 17 00:00:00 2001 From: Vedant Mahabaleshwarkar Date: Tue, 5 Sep 2023 20:37:11 -0400 Subject: [PATCH 1/7] feat: Add per model metrics (#90) - Add `modelId` parameter to `logTimingMetricDuration` function in `Metrics.java`: - `modelmesh_cache_miss_milliseconds` - `modelmesh_loadmodel_milliseconds` - `modelmesh_unloadmodel_milliseconds` - `modelmesh_req_queue_delay_milliseconds` - `modelmesh_model_sizing_milliseconds` - `modelmesh_age_at_eviction_milliseconds` - Add `modelId` parameter to `logSizeEventMetric` function in `Metrics.java`: - `modelmesh_loading_queue_delay_milliseconds` - `modelmesh_loaded_model_size_bytes` - Add `modelId` and `vModelId` param to `logRequestMetrics` in `Metrics.java`: - `modelmesh_invoke_model_milliseconds` - `modelmesh_api_request_milliseconds` Closes #60 Signed-off-by: Vedant Mahabaleshwarkar Signed-off-by: Nick Hill Co-authored-by: Prashant Sharma Co-authored-by: Daniele Zonca Co-authored-by: Nick Hill --- .../com/ibm/watson/modelmesh/Metrics.java | 112 +++++++++++++----- .../com/ibm/watson/modelmesh/ModelMesh.java | 40 ++++--- .../ibm/watson/modelmesh/ModelMeshApi.java | 47 ++++++-- .../ibm/watson/modelmesh/VModelManager.java | 1 - .../watson/prometheus/SimpleCollector.java | 2 +- .../modelmesh/ModelMeshMetricsTest.java | 38 +++--- 6 files changed, 166 insertions(+), 74 deletions(-) diff --git a/src/main/java/com/ibm/watson/modelmesh/Metrics.java b/src/main/java/com/ibm/watson/modelmesh/Metrics.java index b246a5c3..7be788fe 100644 --- a/src/main/java/com/ibm/watson/modelmesh/Metrics.java +++ b/src/main/java/com/ibm/watson/modelmesh/Metrics.java @@ -16,6 +16,7 @@ package com.ibm.watson.modelmesh; +import com.google.common.base.Strings; import com.ibm.watson.prometheus.Counter; import com.ibm.watson.prometheus.Gauge; import com.ibm.watson.prometheus.Histogram; @@ -36,34 +37,39 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import java.lang.reflect.Array; import java.net.SocketAddress; import java.nio.channels.DatagramChannel; -import java.util.*; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.stream.Stream; import static com.ibm.watson.modelmesh.Metric.*; +import static com.ibm.watson.modelmesh.Metric.MetricType.*; import static com.ibm.watson.modelmesh.ModelMesh.M; import static com.ibm.watson.modelmesh.ModelMeshEnvVars.MMESH_CUSTOM_ENV_VAR; -import static com.ibm.watson.modelmesh.ModelMeshEnvVars.MMESH_METRICS_ENV_VAR; import static java.util.concurrent.TimeUnit.*; /** * */ interface Metrics extends AutoCloseable { + boolean isPerModelMetricsEnabled(); boolean isEnabled(); void logTimingMetricSince(Metric metric, long prevTime, boolean isNano); - void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano); + void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId); - void logSizeEventMetric(Metric metric, long value); + void logSizeEventMetric(Metric metric, long value, String modelId); void logGaugeMetric(Metric metric, long value); @@ -101,7 +107,7 @@ default void logInstanceStats(final InstanceRecord ir) { * @param respPayloadSize response payload size in bytes (or -1 if not applicable) */ void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code, - int reqPayloadSize, int respPayloadSize); + int reqPayloadSize, int respPayloadSize, String modelId, String vModelId); default void registerGlobals() {} @@ -111,6 +117,11 @@ default void unregisterGlobals() {} default void close() {} Metrics NO_OP_METRICS = new Metrics() { + @Override + public boolean isPerModelMetricsEnabled() { + return false; + } + @Override public boolean isEnabled() { return false; @@ -120,10 +131,10 @@ public boolean isEnabled() { public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {} @Override - public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) {} + public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId){} @Override - public void logSizeEventMetric(Metric metric, long value) {} + public void logSizeEventMetric(Metric metric, long value, String modelId){} @Override public void logGaugeMetric(Metric metric, long value) {} @@ -136,7 +147,7 @@ public void logInstanceStats(InstanceRecord ir) {} @Override public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code, - int reqPayloadSize, int respPayloadSize) {} + int reqPayloadSize, int respPayloadSize, String modelId, String vModelId) {} }; final class PrometheusMetrics implements Metrics { @@ -154,12 +165,14 @@ final class PrometheusMetrics implements Metrics { private final CollectorRegistry registry; private final NettyServer metricServer; private final boolean shortNames; + private final boolean perModelMetricsEnabled; private final EnumMap metricsMap = new EnumMap<>(Metric.class); public PrometheusMetrics(Map params, Map infoMetricParams) throws Exception { int port = 2112; boolean shortNames = true; boolean https = true; + boolean perModelMetricsEnabled = true; String memMetrics = "all"; // default to all for (Entry ent : params.entrySet()) { switch (ent.getKey()) { @@ -170,6 +183,9 @@ public PrometheusMetrics(Map params, Map infoMet throw new Exception("Invalid metrics port: " + ent.getValue()); } break; + case "per_model_metrics": + perModelMetricsEnabled= "true".equalsIgnoreCase(ent.getValue()); + break; case "fq_names": shortNames = !"true".equalsIgnoreCase(ent.getValue()); break; @@ -188,6 +204,7 @@ public PrometheusMetrics(Map params, Map infoMet throw new Exception("Unrecognized metrics config parameter: " + ent.getKey()); } } + this.perModelMetricsEnabled = perModelMetricsEnabled; registry = new CollectorRegistry(); for (Metric m : Metric.values()) { @@ -220,10 +237,15 @@ public PrometheusMetrics(Map params, Map infoMet } if (m == API_REQUEST_TIME || m == API_REQUEST_COUNT || m == INVOKE_MODEL_TIME - || m == INVOKE_MODEL_COUNT || m == REQUEST_PAYLOAD_SIZE || m == RESPONSE_PAYLOAD_SIZE) { - builder.labelNames("method", "code"); + || m == INVOKE_MODEL_COUNT || m == REQUEST_PAYLOAD_SIZE || m == RESPONSE_PAYLOAD_SIZE) { + if (this.perModelMetricsEnabled) { + builder.labelNames("method", "code", "modelId", "vModelId"); + } else { + builder.labelNames("method", "code"); + } + } else if (this.perModelMetricsEnabled && m.type != GAUGE && m.type != COUNTER && m.type != COUNTER_WITH_HISTO) { + builder.labelNames("modelId", "vModelId"); } - Collector collector = builder.name(m.promName).help(m.description).create(); metricsMap.put(m, collector); if (!m.global) { @@ -330,6 +352,11 @@ public void close() { this.metricServer.close(); } + @Override + public boolean isPerModelMetricsEnabled() { + return perModelMetricsEnabled; + } + @Override public boolean isEnabled() { return true; @@ -342,13 +369,23 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) { } @Override - public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) { - ((Histogram) metricsMap.get(metric)).observe(isNano ? elapsed / M : elapsed); + public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId) { + Histogram histogram = (Histogram) metricsMap.get(metric); + if (perModelMetricsEnabled && modelId != null) { + histogram.labels(modelId, "").observe(isNano ? elapsed / M : elapsed); + } else { + histogram.observe(isNano ? elapsed / M : elapsed); + } } @Override - public void logSizeEventMetric(Metric metric, long value) { - ((Histogram) metricsMap.get(metric)).observe(value * metric.newMultiplier); + public void logSizeEventMetric(Metric metric, long value, String modelId) { + Histogram histogram = (Histogram) metricsMap.get(metric); + if (perModelMetricsEnabled) { + histogram.labels(modelId, "").observe(value * metric.newMultiplier); + } else { + histogram.observe(value * metric.newMultiplier); + } } @Override @@ -365,23 +402,37 @@ public void logCounterMetric(Metric metric) { @Override public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code, - int reqPayloadSize, int respPayloadSize) { + int reqPayloadSize, int respPayloadSize, String modelId, String vModelId) { final long elapsedMillis = elapsedNanos / M; final Histogram timingHisto = (Histogram) metricsMap .get(external ? API_REQUEST_TIME : INVOKE_MODEL_TIME); int idx = shortNames ? name.indexOf('/') : -1; - final String methodName = idx == -1 ? name : name.substring(idx + 1); - - timingHisto.labels(methodName, code.name()).observe(elapsedMillis); - + String methodName = idx == -1 ? name : name.substring(idx + 1); + if (perModelMetricsEnabled) { + modelId = Strings.nullToEmpty(modelId); + vModelId = Strings.nullToEmpty(vModelId); + } + if (perModelMetricsEnabled) { + timingHisto.labels(methodName, code.name(), modelId, vModelId).observe(elapsedMillis); + } else { + timingHisto.labels(methodName, code.name()).observe(elapsedMillis); + } if (reqPayloadSize != -1) { - ((Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE)) - .labels(methodName, code.name()).observe(reqPayloadSize); + Histogram reqPayloadHisto = (Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE); + if (perModelMetricsEnabled) { + reqPayloadHisto.labels(methodName, code.name(), modelId, vModelId).observe(reqPayloadSize); + } else { + reqPayloadHisto.labels(methodName, code.name()).observe(reqPayloadSize); + } } if (respPayloadSize != -1) { - ((Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE)) - .labels(methodName, code.name()).observe(respPayloadSize); + Histogram respPayloadHisto = (Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE); + if (perModelMetricsEnabled) { + respPayloadHisto.labels(methodName, code.name(), modelId, vModelId).observe(respPayloadSize); + } else { + respPayloadHisto.labels(methodName, code.name()).observe(respPayloadSize); + } } } @@ -437,6 +488,11 @@ protected StatsDSender createSender(Callable addressLookup, int q + (shortNames ? "short" : "fully-qualified") + " method names"); } + @Override + public boolean isPerModelMetricsEnabled() { + return false; + } + @Override public boolean isEnabled() { return true; @@ -454,12 +510,12 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) { } @Override - public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano) { + public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId) { client.recordExecutionTime(name(metric), isNano ? elapsed / M : elapsed); } @Override - public void logSizeEventMetric(Metric metric, long value) { + public void logSizeEventMetric(Metric metric, long value, String modelId) { if (!legacy) { value *= metric.newMultiplier; } @@ -497,7 +553,7 @@ static String[] getOkTags(String method, boolean shortName) { @Override public void logRequestMetrics(boolean external, String name, long elapsedNanos, Code code, - int reqPayloadSize, int respPayloadSize) { + int reqPayloadSize, int respPayloadSize, String modelId, String vModelId) { final StatsDClient client = this.client; final long elapsedMillis = elapsedNanos / M; final String countName = name(external ? API_REQUEST_COUNT : INVOKE_MODEL_COUNT); diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java index 9755df49..78c776b4 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java @@ -1966,7 +1966,7 @@ final synchronized boolean doRemove(final boolean evicted, // "unload" event if explicit unloading isn't enabled. // Otherwise, this gets recorded in a callback set in the // CacheEntry.unload(int) method - metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, 0L, false); + metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, 0L, false, modelId); metrics.logCounterMetric(Metric.UNLOAD_MODEL); } } @@ -2037,7 +2037,7 @@ public void onSuccess(Boolean reallyHappened) { //TODO probably only log if took longer than a certain time long tookMillis = msSince(beforeNanos); logger.info("Unload of " + modelId + " completed in " + tookMillis + "ms"); - metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, tookMillis, false); + metrics.logTimingMetricDuration(Metric.UNLOAD_MODEL_TIME, tookMillis, false, modelId); metrics.logCounterMetric(Metric.UNLOAD_MODEL); } // else considered trivially succeeded because the corresponding @@ -2158,7 +2158,7 @@ public final void run() { long queueStartTimeNanos = getAndResetLoadingQueueStartTimeNanos(); if (queueStartTimeNanos > 0) { long queueDelayMillis = (nanoTime() - queueStartTimeNanos) / M; - metrics.logSizeEventMetric(Metric.LOAD_MODEL_QUEUE_DELAY, queueDelayMillis); + metrics.logSizeEventMetric(Metric.LOAD_MODEL_QUEUE_DELAY, queueDelayMillis, modelId); // Only log if the priority value is "in the future" which indicates // that there is or were runtime requests waiting for this load. // Otherwise we don't care about arbitrary delays here @@ -2228,7 +2228,7 @@ public final void run() { loadingTimeStats(modelType).recordTime(tookMillis); logger.info("Load of model " + modelId + " type=" + modelType + " completed in " + tookMillis + "ms"); - metrics.logTimingMetricDuration(Metric.LOAD_MODEL_TIME, tookMillis, false); + metrics.logTimingMetricDuration(Metric.LOAD_MODEL_TIME, tookMillis, false, modelId); metrics.logCounterMetric(Metric.LOAD_MODEL); } catch (Throwable t) { loadFuture = null; @@ -2388,7 +2388,7 @@ protected final void complete(LoadedRuntime result, Throwable error) { if (size > 0) { long sizeBytes = size * UNIT_SIZE; logger.info("Model " + modelId + " size = " + size + " units" + ", ~" + mb(sizeBytes)); - metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes); + metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes, modelId); } else { try { long before = nanoTime(); @@ -2397,9 +2397,9 @@ protected final void complete(LoadedRuntime result, Throwable error) { long took = msSince(before), sizeBytes = size * UNIT_SIZE; logger.info("Model " + modelId + " size = " + size + " units" + ", ~" + mb(sizeBytes) + " sizing took " + took + "ms"); - metrics.logTimingMetricDuration(Metric.MODEL_SIZING_TIME, took, false); + metrics.logTimingMetricDuration(Metric.MODEL_SIZING_TIME, took, false, modelId); // this is actually a size (bytes), not a "time" - metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes); + metrics.logSizeEventMetric(Metric.LOADED_MODEL_SIZE, sizeBytes, modelId); } } catch (Exception e) { if (!isInterruption(e) && state == SIZING) { @@ -2722,7 +2722,7 @@ protected void beforeInvoke(int requestWeight) //noinspection ThrowFromFinallyBlock throw new ModelNotHereException(instanceId, modelId); } - metrics.logTimingMetricDuration(Metric.QUEUE_DELAY, tookMillis, false); + metrics.logTimingMetricDuration(Metric.QUEUE_DELAY, tookMillis, false, modelId); } } } @@ -2901,7 +2901,7 @@ public void onEviction(String key, CacheEntry ce, long lastUsed) { logger.info("Evicted " + (failed ? "failed model record" : "model") + " " + key + " from local cache, last used " + readableTime(millisSinceLastUsed) + " ago (" + lastUsed + "ms), invoked " + ce.getTotalInvocationCount() + " times"); - metrics.logTimingMetricDuration(Metric.AGE_AT_EVICTION, millisSinceLastUsed, false); + metrics.logTimingMetricDuration(Metric.AGE_AT_EVICTION, millisSinceLastUsed, false, ce.modelId); metrics.logCounterMetric(Metric.EVICT_MODEL); } @@ -3315,6 +3315,7 @@ protected Map getMap(Object[] arr) { static final String KNOWN_SIZE_CXT_KEY = "tas.known_size"; static final String UNBALANCED_KEY = "mmesh.unbalanced"; static final String DEST_INST_ID_KEY = "tas.dest_iid"; + static final String VMODEL_ID = "vmodelid"; // these are the possible values for the tas.internal context parameter // it won't be set on requests from outside of the cluster, and will @@ -3430,6 +3431,7 @@ protected Object invokeModel(final String modelId, final Method method, final Me } final String tasInternal = contextMap.get(TAS_INTERNAL_CXT_KEY); + final String vModelId = contextMap.getOrDefault(VMODEL_ID, ""); // Set the external request flag if it's not a tasInternal call or if // tasInternal == INTERNAL_REQ. The latter is a new ensureLoaded // invocation originating from within the cluster. @@ -3502,7 +3504,7 @@ protected Object invokeModel(final String modelId, final Method method, final Me throw new ModelNotHereException(instanceId, modelId); } try { - return invokeLocalModel(ce, method, args, modelId); + return invokeLocalModel(ce, method, args, vModelId); } catch (ModelLoadException mle) { mr = registry.get(modelId); if (mr == null || !mr.loadFailedInInstance(instanceId)) { @@ -3716,7 +3718,7 @@ protected Object invokeModel(final String modelId, final Method method, final Me localInvokesInFlight.incrementAndGet(); } try { - Object result = invokeLocalModel(cacheEntry, method, args, modelId); + Object result = invokeLocalModel(cacheEntry, method, args, vModelId); return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result; } finally { if (!favourSelfForHits) { @@ -3936,7 +3938,7 @@ else if (mr.getInstanceIds().containsKey(instanceId)) { // invoke model try { - Object result = invokeLocalModel(cacheEntry, method, args, modelId); + Object result = invokeLocalModel(cacheEntry, method, args, vModelId); return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result; } catch (ModelNotHereException e) { if (loadTargetFilter != null) loadTargetFilter.remove(instanceId); @@ -3991,7 +3993,7 @@ else if (mr.getInstanceIds().containsKey(instanceId)) { if (methodStartNanos > 0L && metrics.isEnabled()) { // only logged here in non-grpc (legacy) mode metrics.logRequestMetrics(true, getRequestMethodName(method, args), - nanoTime() - methodStartNanos, metricStatusCode, -1, -1); + nanoTime() - methodStartNanos, metricStatusCode, -1, -1, modelId, vModelId); } curThread.setName(threadNameBefore); } @@ -4403,17 +4405,17 @@ protected Object invokeRemoteModel(BaseModelMeshService.Iface client, Method met return remoteMeth.invoke(client, ObjectArrays.concat(modelId, args)); } - protected Object invokeLocalModel(CacheEntry ce, Method method, Object[] args, String modelId) + protected Object invokeLocalModel(CacheEntry ce, Method method, Object[] args, String vModelId) throws InterruptedException, TException { - Object result = invokeLocalModel(ce, method, args); + final Object result = _invokeLocalModel(ce, method, args, vModelId); // if this is an ensure-loaded request, check-for and trigger a "chained" load if necessary if (method == null) { - triggerChainedLoadIfNecessary(modelId, result, args, ce.getWeight(), null); + triggerChainedLoadIfNecessary(ce.modelId, result, args, ce.getWeight(), null); } return result; } - private Object invokeLocalModel(CacheEntry ce, Method method, Object[] args) + private Object _invokeLocalModel(CacheEntry ce, Method method, Object[] args, String vModelId) throws InterruptedException, TException { if (method == null) { @@ -4450,7 +4452,7 @@ private Object invokeLocalModel(CacheEntry ce, Method method, Object[] args) long delayMillis = msSince(beforeNanos); logger.info("Cache miss for model invocation, held up " + delayMillis + "ms"); metrics.logCounterMetric(Metric.CACHE_MISS); - metrics.logTimingMetricDuration(Metric.CACHE_MISS_DELAY, delayMillis, false); + metrics.logTimingMetricDuration(Metric.CACHE_MISS_DELAY, delayMillis, false, ce.modelId); } } } else { @@ -4528,7 +4530,7 @@ private Object invokeLocalModel(CacheEntry ce, Method method, Object[] args) ce.afterInvoke(weight, tookNanos); if (code != null && metrics.isEnabled()) { metrics.logRequestMetrics(false, getRequestMethodName(method, args), - tookNanos, code, -1, -1); + tookNanos, code, -1, -1, ce.modelId, vModelId); } } } diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java b/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java index 6f3f8202..715c0efe 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java @@ -30,6 +30,7 @@ import com.ibm.watson.litelinks.server.ReleaseAfterResponse; import com.ibm.watson.litelinks.server.ServerRequestThread; import com.ibm.watson.modelmesh.DataplaneApiConfig.RpcConfig; +import com.ibm.watson.modelmesh.GrpcSupport.InterruptingListener; import com.ibm.watson.modelmesh.ModelMesh.ExtendedStatusInfo; import com.ibm.watson.modelmesh.api.DeleteVModelRequest; import com.ibm.watson.modelmesh.api.DeleteVModelResponse; @@ -68,6 +69,7 @@ import io.grpc.ServerInterceptors; import io.grpc.ServerMethodDefinition; import io.grpc.ServerServiceDefinition; +import io.grpc.Status.Code; import io.grpc.StatusException; import io.grpc.StatusRuntimeException; import io.grpc.netty.GrpcSslContexts; @@ -85,6 +87,7 @@ import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.FastThreadLocalThread; import org.apache.thrift.TException; import org.slf4j.Logger; @@ -344,6 +347,10 @@ protected static void setUnbalancedLitelinksContextParam() { ThreadContext.addContextEntry(ModelMesh.UNBALANCED_KEY, "true"); // unbalanced } + protected static void setVModelIdLiteLinksContextParam(String vModelId) { + ThreadContext.addContextEntry(ModelMesh.VMODEL_ID, vModelId); + } + // ----------------- concrete model management methods @Override @@ -427,6 +434,9 @@ public void ensureLoaded(EnsureLoadedRequest request, StreamObserver resolvedModelId = new FastThreadLocal<>(); + // Returned ModelResponse will be released once the request thread exits so // must be retained before transferring. // non-private to avoid synthetic method access @@ -441,9 +451,13 @@ ModelResponse callModel(String modelId, boolean isVModel, String methodName, Str } String vModelId = modelId; modelId = null; + if (delegate.metrics.isPerModelMetricsEnabled()) { + setVModelIdLiteLinksContextParam(vModelId); + } boolean first = true; while (true) { modelId = vmm().resolveVModelId(vModelId, modelId); + resolvedModelId.set(modelId); if (unbalanced) { setUnbalancedLitelinksContextParam(); } @@ -542,7 +556,7 @@ protected static void respondAndComplete(StreamObserver response, } protected static io.grpc.Status toStatus(Exception e) { - io.grpc.Status s = null; + io.grpc.Status s; String msg = e.getMessage(); if (e instanceof ModelNotFoundException) { return MODEL_NOT_FOUND_STATUS; @@ -655,7 +669,7 @@ public Listener startCall(ServerCall call, Metadata h call.request(2); // request 2 to force failure if streaming method - return new Listener() { + return new Listener<>() { ByteBuf reqMessage; boolean canInvoke = true; Iterable modelIds = mids.modelIds; @@ -707,7 +721,8 @@ public void onHalfClose() { int respReaderIndex = 0; io.grpc.Status status = INTERNAL; - String modelId = null; + String resolvedModelId = null; + String vModelId = null; String requestId = null; ModelResponse response = null; try (InterruptingListener cancelListener = newInterruptingListener()) { @@ -721,16 +736,28 @@ public void onHalfClose() { String balancedMetaVal = headers.get(BALANCED_META_KEY); Iterator midIt = modelIds.iterator(); // guaranteed at least one - modelId = validateModelId(midIt.next(), isVModel); + String modelOrVModelId = validateModelId(midIt.next(), isVModel); if (!midIt.hasNext()) { // single model case (most common) - response = callModel(modelId, isVModel, methodName, - balancedMetaVal, headers, reqMessage).retain(); + if (isVModel) { + ModelMeshApi.resolvedModelId.set(null); + } + try { + response = callModel(modelOrVModelId, isVModel, methodName, + balancedMetaVal, headers, reqMessage).retain(); + } finally { + if (isVModel) { + vModelId = modelOrVModelId; + resolvedModelId = ModelMeshApi.resolvedModelId.getIfExists(); + } else { + resolvedModelId = modelOrVModelId; + } + } } else { // multi-model case (specialized) boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY)); List idList = new ArrayList<>(); - idList.add(modelId); + idList.add(modelOrVModelId); while (midIt.hasNext()) { idList.add(validateModelId(midIt.next(), isVModel)); } @@ -740,7 +767,7 @@ public void onHalfClose() { } finally { if (payloadProcessor != null) { processPayload(reqMessage.readerIndex(reqReaderIndex), - requestId, modelId, methodName, headers, null, true); + requestId, resolvedModelId, methodName, headers, null, true); } else { releaseReqMessage(); } @@ -776,7 +803,7 @@ public void onHalfClose() { data = response.data.readerIndex(respReaderIndex); metadata = response.metadata; } - processPayload(data, requestId, modelId, methodName, metadata, status, releaseResponse); + processPayload(data, requestId, resolvedModelId, methodName, metadata, status, releaseResponse); } else if (releaseResponse && response != null) { response.release(); } @@ -787,7 +814,7 @@ public void onHalfClose() { Metrics metrics = delegate.metrics; if (metrics.isEnabled()) { metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos, - status.getCode(), reqSize, respSize); + status.getCode(), reqSize, respSize, resolvedModelId, vModelId); } } } diff --git a/src/main/java/com/ibm/watson/modelmesh/VModelManager.java b/src/main/java/com/ibm/watson/modelmesh/VModelManager.java index d2706a16..7ad5da8a 100644 --- a/src/main/java/com/ibm/watson/modelmesh/VModelManager.java +++ b/src/main/java/com/ibm/watson/modelmesh/VModelManager.java @@ -27,7 +27,6 @@ import com.ibm.watson.kvutils.KVTable.Helper.TableTxn; import com.ibm.watson.kvutils.KVTable.TableView; import com.ibm.watson.kvutils.factory.KVUtilsFactory; -import com.ibm.watson.litelinks.ThreadContext; import com.ibm.watson.litelinks.ThreadPoolHelper; import com.ibm.watson.modelmesh.GrpcSupport.InterruptingListener; import com.ibm.watson.modelmesh.api.ModelInfo; diff --git a/src/main/java/com/ibm/watson/prometheus/SimpleCollector.java b/src/main/java/com/ibm/watson/prometheus/SimpleCollector.java index ffca070b..c7b25c1f 100644 --- a/src/main/java/com/ibm/watson/prometheus/SimpleCollector.java +++ b/src/main/java/com/ibm/watson/prometheus/SimpleCollector.java @@ -161,7 +161,7 @@ private static int nextIdx(int i, int len) { private void validateCount(int count) { if (count != labelCount) { - throw new IllegalArgumentException("Incorrect number of labels."); + throw new IllegalArgumentException("Incorrect number of labels. Expected: " + labelCount + ", got: " + count); } } diff --git a/src/test/java/com/ibm/watson/modelmesh/ModelMeshMetricsTest.java b/src/test/java/com/ibm/watson/modelmesh/ModelMeshMetricsTest.java index dc6ee35e..a78cef1c 100644 --- a/src/test/java/com/ibm/watson/modelmesh/ModelMeshMetricsTest.java +++ b/src/test/java/com/ibm/watson/modelmesh/ModelMeshMetricsTest.java @@ -32,6 +32,7 @@ import io.grpc.ManagedChannel; import io.grpc.netty.NettyChannelBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import javax.net.ssl.SSLContext; @@ -76,10 +77,11 @@ protected int requestCount() { @Override protected Map extraEnvVars() { - return ImmutableMap.of("MM_METRICS", "prometheus:port=" + METRICS_PORT + ";scheme=" + SCHEME); + return ImmutableMap.of("MM_METRICS", "prometheus:port=" + METRICS_PORT + ";scheme=" + SCHEME + + ";per_model_metrics=true"); } - @Test + @BeforeAll public void metricsTest() throws Exception { ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 9000).usePlaintext().build(); @@ -151,11 +153,11 @@ public void metricsTest() throws Exception { } } - public void verifyMetrics() throws Exception { + protected Map prepareMetrics() throws Exception { // Insecure trust manager - skip TLS verification SSLContext sslContext = SSLContext.getInstance("TLS"); sslContext.init(null, InsecureTrustManagerFactory.INSTANCE.getTrustManagers(), null); - + HttpClient client = HttpClient.newBuilder().sslContext(sslContext).build(); HttpRequest metricsRequest = HttpRequest.newBuilder() .uri(URI.create(SCHEME + "://localhost:" + METRICS_PORT + "/metrics")).build(); @@ -172,29 +174,35 @@ public void verifyMetrics() throws Exception { .filter(Matcher::matches) .collect(Collectors.toMap(m -> m.group(1), m -> Double.parseDouble(m.group(2)))); + return metrics; + } + + @Test + public void verifyMetrics() throws Exception { + // Insecure trust manager - skip TLS verification + Map metrics = prepareMetrics(); + System.out.println(metrics.size() + " metrics scraped"); // Spot check some expected metrics and values // External response time should all be < 2000ms (includes cache hit loading time) - assertEquals(40.0, metrics.get("modelmesh_api_request_milliseconds_bucket{method=\"predict\",code=\"OK\",le=\"2000.0\",}")); + assertEquals(40.0, metrics.get("modelmesh_api_request_milliseconds_bucket{method=\"predict\",code=\"OK\",modelId=\"myModel\",vModelId=\"\",le=\"2000.0\",}")); // External response time should all be < 200ms (includes cache hit loading time) - assertEquals(40.0, metrics.get("modelmesh_invoke_model_milliseconds_bucket{method=\"predict\",code=\"OK\",le=\"200.0\",}")); + assertEquals(40.0, + metrics.get("modelmesh_invoke_model_milliseconds_bucket{method=\"predict\",code=\"OK\",modelId=\"myModel\",vModelId=\"\",le=\"120000.0\",}")); // Simulated model sizing time is < 200ms - assertEquals(1.0, metrics.get("modelmesh_model_sizing_milliseconds_bucket{le=\"200.0\",}")); + assertEquals(1.0, metrics.get("modelmesh_model_sizing_milliseconds_bucket{modelId=\"myModel\",vModelId=\"\",le=\"60000.0\",}")); // Simulated model sizing time is > 50ms - assertEquals(0.0, metrics.get("modelmesh_model_sizing_milliseconds_bucket{le=\"50.0\",}")); + assertEquals(0.0, metrics.get("modelmesh_model_sizing_milliseconds_bucket{modelId=\"myModel\",vModelId=\"\",le=\"50.0\",}")); // Simulated model size is between 64MiB and 256MiB - assertEquals(0.0, metrics.get("modelmesh_loaded_model_size_bytes_bucket{le=\"6.7108864E7\",}")); - assertEquals(1.0, metrics.get("modelmesh_loaded_model_size_bytes_bucket{le=\"2.68435456E8\",}")); + assertEquals(0.0, metrics.get("modelmesh_loaded_model_size_bytes_bucket{modelId=\"myModel\",vModelId=\"\",le=\"6.7108864E7\",}")); + assertEquals(1.0, metrics.get("modelmesh_loaded_model_size_bytes_bucket{modelId=\"myModel\",vModelId=\"\",le=\"2.68435456E8\",}")); // One model is loaded - assertEquals(1.0, metrics.get("modelmesh_models_loaded_total")); assertEquals(1.0, metrics.get("modelmesh_instance_models_total")); // Histogram counts should reflect the two payload sizes (30 small, 10 large) - assertEquals(30.0, metrics.get("modelmesh_request_size_bytes_bucket{method=\"predict\",code=\"OK\",le=\"128.0\",}")); - assertEquals(40.0, metrics.get("modelmesh_request_size_bytes_bucket{method=\"predict\",code=\"OK\",le=\"2097152.0\",}")); - assertEquals(30.0, metrics.get("modelmesh_response_size_bytes_bucket{method=\"predict\",code=\"OK\",le=\"128.0\",}")); - assertEquals(40.0, metrics.get("modelmesh_response_size_bytes_bucket{method=\"predict\",code=\"OK\",le=\"2097152.0\",}")); + assertEquals(30.0, metrics.get("modelmesh_request_size_bytes_bucket{method=\"predict\",code=\"OK\",modelId=\"myModel\",vModelId=\"\",le=\"128.0\",}")); + assertEquals(40.0, metrics.get("modelmesh_request_size_bytes_bucket{method=\"predict\",code=\"OK\",modelId=\"myModel\",vModelId=\"\",le=\"2097152.0\",}")); // Memory metrics assertTrue(metrics.containsKey("netty_pool_mem_allocated_bytes{area=\"direct\",}")); From 811d9583b885936c670c59262303dbf1c4f0bf0f Mon Sep 17 00:00:00 2001 From: cezhang Date: Fri, 8 Sep 2023 09:58:22 +0800 Subject: [PATCH 2/7] docs: fix typo (#103) Signed-off-by: cezhang --- config/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/README.md b/config/README.md index 6edeb5cc..1c0cf76c 100644 --- a/config/README.md +++ b/config/README.md @@ -8,7 +8,7 @@ The `examples` directory contains example Kustomization overlays to demonstrate - `custom-example` is an example of an overlay to deploy model-mesh with a custom model-serving runtime image - `custom-example-uds` extends `custom-example` to use a unix domain socket for intra-pod communication -- `type-constraints-example` is an example of a heterogeneous model-mesh deployment comprising two kubernetes Deployments with a single Service. It employs type constraints to control assignments of models to pod subsets based on laebels. +- `type-constraints-example` is an example of a heterogeneous model-mesh deployment comprising two kubernetes Deployments with a single Service. It employs type constraints to control assignments of models to pod subsets based on labels. The following patches are provided in `base/patches` and can be selectively included/modified in your custom overlay: From de7836a132c1b48211a206f5de575cf61c5d4cfe Mon Sep 17 00:00:00 2001 From: Christian Kadner Date: Thu, 28 Sep 2023 14:30:31 -0700 Subject: [PATCH 3/7] chore: Add new approvers to OWNERS file (#109) Add new approvers to OWNERS file following KServe community process. Closes kserve/community#6 Approved by @njhill --- Signed-off-by: Christian Kadner --- OWNERS | 2 ++ 1 file changed, 2 insertions(+) diff --git a/OWNERS b/OWNERS index 38a32051..84f618b6 100644 --- a/OWNERS +++ b/OWNERS @@ -1,6 +1,8 @@ approvers: - njhill - tjohnson31415 + - ckadner + - rafvasq reviewers: - njhill - tjohnson31415 From 757b91ed4a4bf91edd4b9d21e7fb5dc1c7ec96c5 Mon Sep 17 00:00:00 2001 From: Sviatoslav Kokurin Date: Sat, 30 Sep 2023 00:50:30 +0300 Subject: [PATCH 4/7] fix: Fix not reading limits when cgroup v2 used (#107) Docker image was not picking up memory limits when cgroup v2 was used. This lead to incorrect java heap memory limits that lead to crashes due to OOMKilled as described in #106. Add extra check for limits in cgroup v2 file for the image to correctly read memory limits with both cgroup v1 and cgroup v2. Resolves #106 --- Signed-off-by: funbiscuit --- src/main/scripts/start.sh | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/main/scripts/start.sh b/src/main/scripts/start.sh index a9e97b98..c5fe44eb 100644 --- a/src/main/scripts/start.sh +++ b/src/main/scripts/start.sh @@ -309,11 +309,15 @@ fi echo $$ > ${ANCHOR_FILE} if [ "$MEM_LIMIT_MB" = "" ]; then - DOCKER_LIM_FILE="/sys/fs/cgroup/memory/memory.limit_in_bytes" - - if [ -e "${DOCKER_LIM_FILE}" ]; then - MEM_LIMIT_MB=$(($(cat ${DOCKER_LIM_FILE})/1024/1024)) - echo "Using process mem limit of ${MEM_LIMIT_MB}MiB from ${DOCKER_LIM_FILE}" + CGROUP_V1_LIM_FILE="/sys/fs/cgroup/memory/memory.limit_in_bytes" + CGROUP_V2_LIM_FILE="/sys/fs/cgroup/memory.max" + + if [ -e "${CGROUP_V1_LIM_FILE}" ]; then + MEM_LIMIT_MB=$(($(cat ${CGROUP_V1_LIM_FILE})/1024/1024)) + echo "Using process mem limit of ${MEM_LIMIT_MB}MiB from ${CGROUP_V1_LIM_FILE}" + elif [ -e "${CGROUP_V2_LIM_FILE}" ]; then + MEM_LIMIT_MB=$(($(cat ${CGROUP_V2_LIM_FILE})/1024/1024)) + echo "Using process mem limit of ${MEM_LIMIT_MB}MiB from ${CGROUP_V2_LIM_FILE}" else MEM_LIMIT_MB="1536" echo "No process mem limit provided or found, defaulting to ${MEM_LIMIT_MB}MiB" From 45dde9d8090275dfb8fcffa5fdffd9a42ff46ba2 Mon Sep 17 00:00:00 2001 From: heyselbi Date: Fri, 6 Oct 2023 14:01:21 -0400 Subject: [PATCH 5/7] Revert "chore: Add new approvers to OWNERS file (#109)" This reverts commit de7836a132c1b48211a206f5de575cf61c5d4cfe. --- OWNERS | 2 -- 1 file changed, 2 deletions(-) diff --git a/OWNERS b/OWNERS index 84f618b6..38a32051 100644 --- a/OWNERS +++ b/OWNERS @@ -1,8 +1,6 @@ approvers: - njhill - tjohnson31415 - - ckadner - - rafvasq reviewers: - njhill - tjohnson31415 From c41babbc776512e9bee089f6003b5f2ad254c099 Mon Sep 17 00:00:00 2001 From: Spolti Date: Fri, 6 Oct 2023 16:30:46 -0300 Subject: [PATCH 6/7] add new members to the OWNERS file Signed-off-by: Spolti --- OWNERS | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/OWNERS b/OWNERS index 592faa9a..95f2cce3 100644 --- a/OWNERS +++ b/OWNERS @@ -1,18 +1,24 @@ approvers: - anishasthana - danielezonca + - davidesalerno - heyselbi - israel-hdez - Jooho + - rpancham + - spolti - vaibhavjainwiz - VedantMahabaleshwarkar - Xaenalt reviewers: - anishasthana - danielezonca + - davidesalerno - heyselbi - israel-hdez - Jooho + - rpancham + - spolti - vaibhavjainwiz - VedantMahabaleshwarkar - Xaenalt From cde3da1d27518d221f3ba4ebd2b4e588c3270f7c Mon Sep 17 00:00:00 2001 From: Spolti Date: Wed, 18 Oct 2023 15:00:44 -0300 Subject: [PATCH 7/7] Bump netty - CVE-2023-44487 Signed-off-by: Spolti --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 489fc9d6..aa1997df 100644 --- a/pom.xml +++ b/pom.xml @@ -58,7 +58,7 @@ ${env.BUILD_TAG} 1.57.2 - 4.1.96.Final + 4.1.100.Final 1.7.2 0.5.1 0.0.22