From e85651dcdb997ae311f193136ef2653c5057da82 Mon Sep 17 00:00:00 2001 From: "Dmitry A. Grechka" Date: Mon, 29 Aug 2022 20:03:19 +0300 Subject: [PATCH] Knn based image similarity search (#4) * supporting knn image similarity search --- CardIndexRestAPI/CardIndexRestAPI.csproj | 4 + .../Controllers/SolrProxyController.cs | 79 ++++++++++++++----- CardIndexRestAPI/DataSchema/Requests.cs | 4 + CardIndexRestAPI/Dockerfile | 2 +- CardIndexRestAPI/ISolrSearchConfig.cs | 2 +- CardIndexRestAPI/Program.cs | 6 +- CardIndexRestAPI/StaticSolrSearchConfig.cs | 6 +- 7 files changed, 76 insertions(+), 27 deletions(-) diff --git a/CardIndexRestAPI/CardIndexRestAPI.csproj b/CardIndexRestAPI/CardIndexRestAPI.csproj index 30d62ca..df404d6 100644 --- a/CardIndexRestAPI/CardIndexRestAPI.csproj +++ b/CardIndexRestAPI/CardIndexRestAPI.csproj @@ -12,4 +12,8 @@ + + + + diff --git a/CardIndexRestAPI/Controllers/SolrProxyController.cs b/CardIndexRestAPI/Controllers/SolrProxyController.cs index 35f6488..3513f61 100644 --- a/CardIndexRestAPI/Controllers/SolrProxyController.cs +++ b/CardIndexRestAPI/Controllers/SolrProxyController.cs @@ -11,6 +11,7 @@ using System.Net.Http; using System.Web; using static CardIndexRestAPI.DataSchema.Requests; +using System.Numerics; namespace SolrAPI.Controllers { @@ -41,7 +42,8 @@ public SolrProxyController(ISolrSearchConfig solrAddress) /// Where to write proxied data /// private static async Task ProxyHttpPost(string requestID, string requstURL, HttpContent? content, HttpResponse response) { - HttpClient client = new HttpClient(); + HttpClient client = new HttpClient(); + client.Timeout = TimeSpan.FromMinutes(20); client.DefaultRequestHeaders.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("application/json")); Trace.TraceInformation($"{requestID}: Issung request {requstURL}."); @@ -69,6 +71,7 @@ private static async Task ProxyHttpPost(string requestID, string requstURL, Http await outStream.DisposeAsync(); } + /* [EnableCors] [HttpPost("MatchedCardsSearch")] public async Task MatchedCardsSearch([FromBody]GetMatchesRequest request) @@ -143,6 +146,8 @@ public async Task MatchedCardsSearch([FromBody]GetMatchesRequest request) } } + */ + [EnableCors] [HttpPost("MatchedImagesSearch")] public async Task MatchedImagesSearch([FromBody] GetMatchesRequest request) @@ -151,8 +156,15 @@ public async Task MatchedImagesSearch([FromBody] GetMatchesRequest request) Trace.TraceInformation($"{requestHash}: Got request."); try { - string featureDims = String.Join(",", Enumerable.Range(0, request.Features.Length).Select(idx => $"{request.FeaturesIdent}_{idx}_d")); - string featuresTargetVal = String.Join(',', request.Features); + + var features = request.Features.ToArray(); //.Take(900).ToArray(); + double norm = Math.Sqrt(features.Sum(x => x * x)); + features = features.Select(x => x / norm).ToArray(); + + Trace.TraceInformation($"Feature length is {features.Length}"); + string featureDims = String.Join(",", Enumerable.Range(0, features.Length).Select(idx => $"{request.FeaturesIdent}_{idx}_d")); + //string featuresTargetVal = String.Join(',', features.Select(d => Math.Round(d,3))); + string featuresTargetVal = String.Join(',', features); DateTime shortTermSearchStart = request.EventType switch { @@ -183,26 +195,57 @@ public async Task MatchedImagesSearch([FromBody] GetMatchesRequest request) string longTermSpaceSpec = $"{{!geofilt sfield=location pt={request.Lat},{request.Lon} d={this.searchConfig.LongTermSearchRadiusKm}}}"; string longTermSearchTerm = $"{longTermTimeSpec} AND {longTermSpaceSpec}"; - string typeSearchTerm = request.EventType switch + List additionalFilter = new List(); + if (request.FilterFar ?? false) + additionalFilter.Add($"({longTermSearchTerm})"); + if (request.FilterLongAgo ?? false) + additionalFilter.Add($"({shortTermSearchTerm})"); + string additionalFiltersStr = String.Join(" OR ", additionalFilter); + + string typeFilterTerm = request.EventType switch // this one inverts the specification { "Found" => "card_type:Lost", "Lost" => "card_type:Found", _ => throw new ArgumentException($"Unknown EventType: {request.EventType}") }; + string animalFilterTerm = request.Animal switch + { + "Cat" => "animal:Cat", + "Dog" => "animal:Dog", + _ => throw new ArgumentException($"Unknown Animal: {request.Animal}") + }; - string solrFindLostRequest = - $"top(n={this.searchConfig.MaxReturnCount},having(select(search({this.searchConfig.ImagesCollectionName},q=\"animal:{request.Animal} AND {typeSearchTerm} AND (({shortTermSearchTerm})OR({longTermSearchTerm}))\",fl=\"id, event_time, {featureDims}\",sort=\"event_time asc\",qt=\"/export\"),id,cosineSimilarity(array({featureDims}), array({featuresTargetVal})) as similarity), gt(similarity, {this.searchConfig.SimilarityThreshold})),sort=\"similarity desc\")"; - //Trace.TraceInformation($"{requestHash}: Got request. Issuing: {solrFindLostRequest}"); - - //solrFindLostRequest = "top(n=100,select(search(kashtankaimages,q=\"animal:Cat AND card_type:Lost AND ((event_time:[2019-06-19T21:00:00.0000000Z TO 2019-08-02T21:00:00.0000000Z] AND {!geofilt sfield=location pt=56.273015,43.93563 d=1000})OR(event_time:[ * TO 2019-08-02T21:00:00.0000000Z] AND {!geofilt sfield=location pt=56.273015,43.93563 d=20}))\",fl=\"id, event_time\",sort=\"event_time asc\",qt=\"/export\"),id,),sort=\"similarity desc\")"; + string cardTypeAndAnimalFilter = $"{typeFilterTerm} AND {animalFilterTerm}"; + + // example from docs: https://solr.apache.org/guide/solr/latest/query-guide/dense-vector-search.html#query-time + // { !knn f = vector topK = 10}[1.0, 2.0, 3.0, 4.0] + + const string embeddingName = "calvin_zhirui_embedding"; + + string similaritySearchExpr = $"{{!knn f={embeddingName} topK={searchConfig.SimilarityKnnTopK}}}[{featuresTargetVal}]"; + + + Trace.TraceInformation($"sim search expr: ${similaritySearchExpr}"); + + List> requestParams = new List>(new KeyValuePair[] { + //new KeyValuePair("expr",solrFindLostRequest) + new KeyValuePair("q",similaritySearchExpr), + new KeyValuePair("fl",$"id,{embeddingName}"), + new KeyValuePair("fq",cardTypeAndAnimalFilter), + new KeyValuePair("rows",searchConfig.MaxReturnCount.ToString()) + }); + + if (!string.IsNullOrEmpty(additionalFiltersStr)) + requestParams.Add(new KeyValuePair("fq", additionalFiltersStr)); // To avoid "URL is too long" we pass the request inside POST body - FormUrlEncodedContent requestContent = new FormUrlEncodedContent(new KeyValuePair[] { - new KeyValuePair("expr",solrFindLostRequest) - }); ; + FormUrlEncodedContent requestContent = new FormUrlEncodedContent(requestParams); + + + //return NotFound(); - await ProxyHttpPost(requestHash, this.solrImagesStreamingExpressionsURL, requestContent, Response); + await ProxyHttpPost(requestHash, this.solrImagesSelectExpressionsURL, requestContent, Response); Trace.TraceInformation($"{requestHash}: Transmitted successfully"); @@ -241,13 +284,11 @@ public async Task LatestCards([FromQuery]int maxCardsCount=10, [FromQuery] strin if (!string.IsNullOrEmpty(typeConstraint)) { requestParams.Add("fq", typeConstraint); } - - string queryStr = - string.Join('&', requestParams.Select(kvp => $"{HttpUtility.UrlEncode(kvp.Key)}={HttpUtility.UrlEncode(kvp.Value)}")); - - string finalURL = $"{this.solrCardsSelectExpressionsURL}?{queryStr}"; + + FormUrlEncodedContent requestContent = new FormUrlEncodedContent(requestParams); + try { - await ProxyHttpPost("latest cards request", finalURL,null, Response); + await ProxyHttpPost("latest cards request", this.solrCardsSelectExpressionsURL, requestContent, Response); } catch (Exception err) { diff --git a/CardIndexRestAPI/DataSchema/Requests.cs b/CardIndexRestAPI/DataSchema/Requests.cs index d28550e..03ac89b 100644 --- a/CardIndexRestAPI/DataSchema/Requests.cs +++ b/CardIndexRestAPI/DataSchema/Requests.cs @@ -16,6 +16,8 @@ public class GetMatchesRequest { public double[] Features { get; set; } public string FeaturesIdent { get; set; } + public bool? FilterFar { get; set; } + public bool? FilterLongAgo { get; set; } public override int GetHashCode() { @@ -24,6 +26,8 @@ public override int GetHashCode() (this.Animal?.GetHashCode() ?? 0) ^ this.EventTime.GetHashCode() ^ this.EventType.GetHashCode() ^ + (this.FilterFar?.GetHashCode() ?? 0) ^ + (this.FilterLongAgo?.GetHashCode() ?? 0) ^ this.Features.Select(f => f.GetHashCode()).Aggregate(0, (acc, elem) => acc ^ elem) ^ this.Features.Length.GetHashCode() ^ (this.FeaturesIdent?.GetHashCode() ?? 0); diff --git a/CardIndexRestAPI/Dockerfile b/CardIndexRestAPI/Dockerfile index 37dbd73..317bb80 100644 --- a/CardIndexRestAPI/Dockerfile +++ b/CardIndexRestAPI/Dockerfile @@ -27,7 +27,7 @@ ENV LONG_TERM_SEARCH_RADIUS_KM=20.0 ENV SHORT_TERM_SEARCH_RADIUS_KM=200.0 ENV SHORT_TERM_LENGTH_DAYS=30 ENV REVERSE_TIME_GAP_LENGTH_DAYS=14 -ENV SIMILARITY_THRESHOLD=0.95 +ENV SIMILARITY_KNN_TOP_K=200 COPY --from=publish /app/publish . ENTRYPOINT ["dotnet", "CardIndexRestAPI.dll"] \ No newline at end of file diff --git a/CardIndexRestAPI/ISolrSearchConfig.cs b/CardIndexRestAPI/ISolrSearchConfig.cs index 9d61598..8da799e 100644 --- a/CardIndexRestAPI/ISolrSearchConfig.cs +++ b/CardIndexRestAPI/ISolrSearchConfig.cs @@ -15,6 +15,6 @@ public interface ISolrSearchConfig public double ShortTermSearchRadiusKm { get; } public TimeSpan ShortTermLength { get; } public TimeSpan ReverseTimeGapLength { get; } - public double SimilarityThreshold { get; } + public int SimilarityKnnTopK { get;} } } diff --git a/CardIndexRestAPI/Program.cs b/CardIndexRestAPI/Program.cs index edfefac..50f875d 100644 --- a/CardIndexRestAPI/Program.cs +++ b/CardIndexRestAPI/Program.cs @@ -55,8 +55,8 @@ public static void Main(string[] args) TimeSpan reverseTimeGapLength = TimeSpan.FromDays(int.Parse(Environment.GetEnvironmentVariable("REVERSE_TIME_GAP_LENGTH_DAYS") ?? "14")); Trace.TraceInformation($"REVERSE_TIME_GAP_LENGTH_DAYS: {reverseTimeGapLength}"); - double similarityThreshold = double.Parse(Environment.GetEnvironmentVariable("SIMILARITY_THRESHOLD") ?? "0.1"); - Trace.TraceInformation($"SIMILARITY_THRESHOLD: {similarityThreshold}"); + int similarityKnnTopK = int.Parse(Environment.GetEnvironmentVariable("SIMILARITY_KNN_TOP_K") ?? "200"); + Trace.TraceInformation($"SIMILARITY_KNN_TOP_K: {similarityKnnTopK}"); //builder.Services.AddSingleton(typeof(IPhotoStorage), storage); builder.Services.AddSingleton(typeof(ISolrSearchConfig), @@ -65,7 +65,7 @@ public static void Main(string[] args) longTermSearchRadiusKm, shortTermSearchRadiusKm, shortTermLength, - similarityThreshold, + similarityKnnTopK, reverseTimeGapLength )); } diff --git a/CardIndexRestAPI/StaticSolrSearchConfig.cs b/CardIndexRestAPI/StaticSolrSearchConfig.cs index 8aff6be..32ffcb7 100644 --- a/CardIndexRestAPI/StaticSolrSearchConfig.cs +++ b/CardIndexRestAPI/StaticSolrSearchConfig.cs @@ -21,7 +21,7 @@ public class StaticSolrSearchConfig : ISolrSearchConfig public TimeSpan ShortTermLength { get; private set; } - public double SimilarityThreshold { get; private set; } + public int SimilarityKnnTopK { get; private set; } public TimeSpan ReverseTimeGapLength { get; private set; } @@ -33,7 +33,7 @@ public StaticSolrSearchConfig( double longTermSearchRadiusKm, double shortTermSearchRadiusKm, TimeSpan shortTermLength, - double similarityThreshold, + int similarityKnnTopK, TimeSpan reverseTimeGapLength ) { this.SolrAddress = address; @@ -43,7 +43,7 @@ TimeSpan reverseTimeGapLength this.LongTermSearchRadiusKm = longTermSearchRadiusKm; this.ShortTermSearchRadiusKm = shortTermSearchRadiusKm; this.ShortTermLength = shortTermLength; - this.SimilarityThreshold = similarityThreshold; + this.SimilarityKnnTopK = similarityKnnTopK; this.ReverseTimeGapLength = reverseTimeGapLength; } }