Skip to content

Commit

Permalink
Knn based image similarity search (#4)
Browse files Browse the repository at this point in the history
* supporting knn image similarity search
  • Loading branch information
dgrechka authored Aug 29, 2022
1 parent 1346634 commit e85651d
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CardIndexRestAPI/CardIndexRestAPI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.2.3" />
</ItemGroup>

<ItemGroup>
<Folder Include="Properties\PublishProfiles\" />
</ItemGroup>

</Project>
79 changes: 60 additions & 19 deletions CardIndexRestAPI/Controllers/SolrProxyController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Net.Http;
using System.Web;
using static CardIndexRestAPI.DataSchema.Requests;
using System.Numerics;

namespace SolrAPI.Controllers
{
Expand Down Expand Up @@ -41,7 +42,8 @@ public SolrProxyController(ISolrSearchConfig solrAddress)
/// <param name="response">Where to write proxied data</param>
/// <returns></returns>
private static async Task ProxyHttpPost(string requestID, string requstURL, HttpContent? content, HttpResponse response) {
HttpClient client = new HttpClient();
HttpClient client = new HttpClient();
client.Timeout = TimeSpan.FromMinutes(20);
client.DefaultRequestHeaders.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("application/json"));

Trace.TraceInformation($"{requestID}: Issung request {requstURL}.");
Expand Down Expand Up @@ -69,6 +71,7 @@ private static async Task ProxyHttpPost(string requestID, string requstURL, Http
await outStream.DisposeAsync();
}

/*
[EnableCors]
[HttpPost("MatchedCardsSearch")]
public async Task MatchedCardsSearch([FromBody]GetMatchesRequest request)
Expand Down Expand Up @@ -143,6 +146,8 @@ public async Task MatchedCardsSearch([FromBody]GetMatchesRequest request)
}
}
*/

[EnableCors]
[HttpPost("MatchedImagesSearch")]
public async Task MatchedImagesSearch([FromBody] GetMatchesRequest request)
Expand All @@ -151,8 +156,15 @@ public async Task MatchedImagesSearch([FromBody] GetMatchesRequest request)
Trace.TraceInformation($"{requestHash}: Got request.");
try
{
string featureDims = String.Join(",", Enumerable.Range(0, request.Features.Length).Select(idx => $"{request.FeaturesIdent}_{idx}_d"));
string featuresTargetVal = String.Join(',', request.Features);

var features = request.Features.ToArray(); //.Take(900).ToArray();
double norm = Math.Sqrt(features.Sum(x => x * x));
features = features.Select(x => x / norm).ToArray();

Trace.TraceInformation($"Feature length is {features.Length}");
string featureDims = String.Join(",", Enumerable.Range(0, features.Length).Select(idx => $"{request.FeaturesIdent}_{idx}_d"));
//string featuresTargetVal = String.Join(',', features.Select(d => Math.Round(d,3)));
string featuresTargetVal = String.Join(',', features);

DateTime shortTermSearchStart = request.EventType switch
{
Expand Down Expand Up @@ -183,26 +195,57 @@ public async Task MatchedImagesSearch([FromBody] GetMatchesRequest request)
string longTermSpaceSpec = $"{{!geofilt sfield=location pt={request.Lat},{request.Lon} d={this.searchConfig.LongTermSearchRadiusKm}}}";
string longTermSearchTerm = $"{longTermTimeSpec} AND {longTermSpaceSpec}";

string typeSearchTerm = request.EventType switch
List<string> additionalFilter = new List<string>();
if (request.FilterFar ?? false)
additionalFilter.Add($"({longTermSearchTerm})");
if (request.FilterLongAgo ?? false)
additionalFilter.Add($"({shortTermSearchTerm})");
string additionalFiltersStr = String.Join(" OR ", additionalFilter);

string typeFilterTerm = request.EventType switch // this one inverts the specification
{
"Found" => "card_type:Lost",
"Lost" => "card_type:Found",
_ => throw new ArgumentException($"Unknown EventType: {request.EventType}")
};

string animalFilterTerm = request.Animal switch
{
"Cat" => "animal:Cat",
"Dog" => "animal:Dog",
_ => throw new ArgumentException($"Unknown Animal: {request.Animal}")
};

string solrFindLostRequest =
$"top(n={this.searchConfig.MaxReturnCount},having(select(search({this.searchConfig.ImagesCollectionName},q=\"animal:{request.Animal} AND {typeSearchTerm} AND (({shortTermSearchTerm})OR({longTermSearchTerm}))\",fl=\"id, event_time, {featureDims}\",sort=\"event_time asc\",qt=\"/export\"),id,cosineSimilarity(array({featureDims}), array({featuresTargetVal})) as similarity), gt(similarity, {this.searchConfig.SimilarityThreshold})),sort=\"similarity desc\")";
//Trace.TraceInformation($"{requestHash}: Got request. Issuing: {solrFindLostRequest}");

//solrFindLostRequest = "top(n=100,select(search(kashtankaimages,q=\"animal:Cat AND card_type:Lost AND ((event_time:[2019-06-19T21:00:00.0000000Z TO 2019-08-02T21:00:00.0000000Z] AND {!geofilt sfield=location pt=56.273015,43.93563 d=1000})OR(event_time:[ * TO 2019-08-02T21:00:00.0000000Z] AND {!geofilt sfield=location pt=56.273015,43.93563 d=20}))\",fl=\"id, event_time\",sort=\"event_time asc\",qt=\"/export\"),id,),sort=\"similarity desc\")";
string cardTypeAndAnimalFilter = $"{typeFilterTerm} AND {animalFilterTerm}";

// example from docs: https://solr.apache.org/guide/solr/latest/query-guide/dense-vector-search.html#query-time
// { !knn f = vector topK = 10}[1.0, 2.0, 3.0, 4.0]

const string embeddingName = "calvin_zhirui_embedding";

string similaritySearchExpr = $"{{!knn f={embeddingName} topK={searchConfig.SimilarityKnnTopK}}}[{featuresTargetVal}]";


Trace.TraceInformation($"sim search expr: ${similaritySearchExpr}");

List<KeyValuePair<string, string>> requestParams = new List<KeyValuePair<string, string>>(new KeyValuePair<string, string>[] {
//new KeyValuePair<string, string>("expr",solrFindLostRequest)
new KeyValuePair<string, string>("q",similaritySearchExpr),
new KeyValuePair<string, string>("fl",$"id,{embeddingName}"),
new KeyValuePair<string, string>("fq",cardTypeAndAnimalFilter),
new KeyValuePair<string, string>("rows",searchConfig.MaxReturnCount.ToString())
});

if (!string.IsNullOrEmpty(additionalFiltersStr))
requestParams.Add(new KeyValuePair<string, string>("fq", additionalFiltersStr));

// To avoid "URL is too long" we pass the request inside POST body
FormUrlEncodedContent requestContent = new FormUrlEncodedContent(new KeyValuePair<string, string>[] {
new KeyValuePair<string, string>("expr",solrFindLostRequest)
}); ;
FormUrlEncodedContent requestContent = new FormUrlEncodedContent(requestParams);


//return NotFound();

await ProxyHttpPost(requestHash, this.solrImagesStreamingExpressionsURL, requestContent, Response);
await ProxyHttpPost(requestHash, this.solrImagesSelectExpressionsURL, requestContent, Response);


Trace.TraceInformation($"{requestHash}: Transmitted successfully");
Expand Down Expand Up @@ -241,13 +284,11 @@ public async Task LatestCards([FromQuery]int maxCardsCount=10, [FromQuery] strin
if (!string.IsNullOrEmpty(typeConstraint)) {
requestParams.Add("fq", typeConstraint);
}

string queryStr =
string.Join('&', requestParams.Select(kvp => $"{HttpUtility.UrlEncode(kvp.Key)}={HttpUtility.UrlEncode(kvp.Value)}"));

string finalURL = $"{this.solrCardsSelectExpressionsURL}?{queryStr}";

FormUrlEncodedContent requestContent = new FormUrlEncodedContent(requestParams);

try {
await ProxyHttpPost("latest cards request", finalURL,null, Response);
await ProxyHttpPost("latest cards request", this.solrCardsSelectExpressionsURL, requestContent, Response);
}
catch (Exception err)
{
Expand Down
4 changes: 4 additions & 0 deletions CardIndexRestAPI/DataSchema/Requests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public class GetMatchesRequest {
public double[] Features { get; set; }
public string FeaturesIdent { get; set; }

public bool? FilterFar { get; set; }
public bool? FilterLongAgo { get; set; }

public override int GetHashCode()
{
Expand All @@ -24,6 +26,8 @@ public override int GetHashCode()
(this.Animal?.GetHashCode() ?? 0) ^
this.EventTime.GetHashCode() ^
this.EventType.GetHashCode() ^
(this.FilterFar?.GetHashCode() ?? 0) ^
(this.FilterLongAgo?.GetHashCode() ?? 0) ^
this.Features.Select(f => f.GetHashCode()).Aggregate(0, (acc, elem) => acc ^ elem) ^
this.Features.Length.GetHashCode() ^
(this.FeaturesIdent?.GetHashCode() ?? 0);
Expand Down
2 changes: 1 addition & 1 deletion CardIndexRestAPI/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ENV LONG_TERM_SEARCH_RADIUS_KM=20.0
ENV SHORT_TERM_SEARCH_RADIUS_KM=200.0
ENV SHORT_TERM_LENGTH_DAYS=30
ENV REVERSE_TIME_GAP_LENGTH_DAYS=14
ENV SIMILARITY_THRESHOLD=0.95
ENV SIMILARITY_KNN_TOP_K=200

COPY --from=publish /app/publish .
ENTRYPOINT ["dotnet", "CardIndexRestAPI.dll"]
2 changes: 1 addition & 1 deletion CardIndexRestAPI/ISolrSearchConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ public interface ISolrSearchConfig
public double ShortTermSearchRadiusKm { get; }
public TimeSpan ShortTermLength { get; }
public TimeSpan ReverseTimeGapLength { get; }
public double SimilarityThreshold { get; }
public int SimilarityKnnTopK { get;}
}
}
6 changes: 3 additions & 3 deletions CardIndexRestAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ public static void Main(string[] args)
TimeSpan reverseTimeGapLength = TimeSpan.FromDays(int.Parse(Environment.GetEnvironmentVariable("REVERSE_TIME_GAP_LENGTH_DAYS") ?? "14"));
Trace.TraceInformation($"REVERSE_TIME_GAP_LENGTH_DAYS: {reverseTimeGapLength}");

double similarityThreshold = double.Parse(Environment.GetEnvironmentVariable("SIMILARITY_THRESHOLD") ?? "0.1");
Trace.TraceInformation($"SIMILARITY_THRESHOLD: {similarityThreshold}");
int similarityKnnTopK = int.Parse(Environment.GetEnvironmentVariable("SIMILARITY_KNN_TOP_K") ?? "200");
Trace.TraceInformation($"SIMILARITY_KNN_TOP_K: {similarityKnnTopK}");

//builder.Services.AddSingleton(typeof(IPhotoStorage), storage);
builder.Services.AddSingleton(typeof(ISolrSearchConfig),
Expand All @@ -65,7 +65,7 @@ public static void Main(string[] args)
longTermSearchRadiusKm,
shortTermSearchRadiusKm,
shortTermLength,
similarityThreshold,
similarityKnnTopK,
reverseTimeGapLength
));
}
Expand Down
6 changes: 3 additions & 3 deletions CardIndexRestAPI/StaticSolrSearchConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class StaticSolrSearchConfig : ISolrSearchConfig

public TimeSpan ShortTermLength { get; private set; }

public double SimilarityThreshold { get; private set; }
public int SimilarityKnnTopK { get; private set; }

public TimeSpan ReverseTimeGapLength { get; private set; }

Expand All @@ -33,7 +33,7 @@ public StaticSolrSearchConfig(
double longTermSearchRadiusKm,
double shortTermSearchRadiusKm,
TimeSpan shortTermLength,
double similarityThreshold,
int similarityKnnTopK,
TimeSpan reverseTimeGapLength
) {
this.SolrAddress = address;
Expand All @@ -43,7 +43,7 @@ TimeSpan reverseTimeGapLength
this.LongTermSearchRadiusKm = longTermSearchRadiusKm;
this.ShortTermSearchRadiusKm = shortTermSearchRadiusKm;
this.ShortTermLength = shortTermLength;
this.SimilarityThreshold = similarityThreshold;
this.SimilarityKnnTopK = similarityKnnTopK;
this.ReverseTimeGapLength = reverseTimeGapLength;
}
}
Expand Down

0 comments on commit e85651d

Please sign in to comment.