Home > Net >  Distance calculation in mongodb aggregate using cosine
Distance calculation in mongodb aggregate using cosine

Time:11-04

I am saving face embeddings as NumPy arrays in MongoDB and using the aggregate pipeline below to find the distance between two arrays using the Euclidean algorithm.

Can someone please help me calculate the distance using cosine distance instead?

A sample JSON document is shown below; the aggregate should return only those documents whose distance is <= 0.68.

{
  "_id": {
    "$oid": "635ff70a16dfa4cd45f02c43"
  },
  "img_path": "1_11",
  "name": "11",
  "embedding": [
    0.04153144732117653,
    -0.0008036745712161064,
    -0.003807373344898224,
    -0.11769875884056091,
    0.03676579147577286,
    0.09997286647558212,
    -0.044136010110378265,
    -0.1692838966846466,
    -0.003151319921016693,
    -0.03791208565235138,
    -0.010753434151411057,
    -0.024590950459241867,
    0.29803258180618286,
    -0.04285677522420883,
    0.20911607146263123,
    0.04455781728029251,
    0.029769904911518097,
    -0.33982840180397034,
    0.010117404162883759,
    -0.4239773750305176,
    -0.005369218066334724,
    0.0714929848909378,
    0.2586987018585205,
    0.007376951165497303,
    -0.03157464414834976,
    0.11055145412683487,
    -0.052830226719379425,
    -0.16745653748512268,
    0.06054156646132469,
    -0.3206060826778412,
    0.054714761674404144,
    -0.10260075330734253,
    0.2717891037464142,
    -0.16717249155044556,
    0.09896406531333923,
    -0.2952454090118408,
    0.010972617194056511,
    0.24918027222156525,
    -0.2690051198005676,
    0.2626166343688965,
    0.2875710725784302,
    0.13260763883590698,
    0.12771351635456085,
    -0.22898457944393158,
    0.18346519768238068,
    -0.06396391987800598,
    -0.09320224076509476,
    0.02307960018515587,
    -0.3165799081325531,
    0.007051767781376839,
    0.06508949398994446,
    -0.15390481054782867,
    0.07253721356391907,
    0.05360442399978638,
    0.02317194454371929,
    0.2602832019329071,
    0.32494112849235535,
    0.10228901356458664,
    -0.026188479736447334,
    0.051889386028051376,
    -0.17360231280326843,
    0.2001030296087265,
    0.11494665592908859,
    -0.1468532681465149,
    -0.3037929832935333,
    0.4243096113204956,
    -0.3967720568180084,
    -0.1580674648284912,
    0.1209937185049057,
    -0.3088080883026123,
    0.14958013594150543,
    -0.1499250829219818,
    0.1793043464422226,
    0.042404696345329285,
    -0.17440882325172424,
    0.014482134953141212,
    0.17418500781059265,
    -0.29781395196914673,
    0.3233387768268585,
    -0.13625966012477875,
    -0.2671341001987457,
    0.1924743503332138,
    -0.0009934399276971817,
    -0.13012878596782684,
    0.04334684833884239,
    -0.047992683947086334,
    -0.0871971845626831,
    0.026077959686517715,
    0.23131468892097473,
    -0.04128192365169525,
    0.11939074844121933,
    0.2669318914413452,
    0.02978256344795227,
    -0.07513333857059479,
    0.09071725606918335,
    0.14345180988311768,
    -0.2577393651008606,
    -0.1343279629945755,
    0.03614958003163338,
    0.04753677546977997,
    0.4799879491329193,
    0.12120816111564636,
    0.04913831502199173,
    -0.1472567766904831,
    -0.1521947830915451,
    -0.016198324039578438,
    0.0709092766046524,
    0.2530268430709839,
    -0.10402888804674149,
    0.12103180587291718,
    -0.00816013291478157,
    -0.15727344155311584,
    -0.09354865550994873,
    0.15803465247154236,
    -0.002220466732978821,
    -0.0023632189258933067,
    0.03150435537099838,
    -0.0761573389172554,
    -0.3728805482387543,
    0.05395852029323578,
    -0.13205842673778534,
    0.019016679376363754,
    0.5200108885765076,
    0.2782735824584961,
    0.08217129856348038,
    0.06879977881908417,
    -0.3019191026687622,
    -0.21047928929328918,
    -0.2397751361131668,
    -0.14399221539497375,
    0.18687033653259277,
    -0.28487658500671387,
    0.11619545519351959,
    -0.18031732738018036,
    -0.059407636523246765,
    -0.11267021298408508,
    -0.02284402772784233,
    -0.45863431692123413,
    -0.06318340450525284,
    0.11655210703611374,
    -0.34693512320518494,
    -0.14945799112319946,
    -0.03837423026561737,
    0.13326743245124817,
    -0.04826241731643677,
    0.0984693095088005,
    -0.21571457386016846,
    -0.005599251948297024,
    -0.1000245064496994,
    0.03078708052635193,
    0.2257998287677765,
    0.23468151688575745,
    0.24614854156970978,
    0.057032980024814606,
    0.02590012177824974,
    0.06637579947710037,
    -0.09635362774133682,
    0.024511300027370453,
    0.054878443479537964,
    -0.019001495093107224,
    0.03533126041293144,
    -0.14802871644496918,
    0.05799974128603935,
    0.17114050686359406,
    -0.10243572294712067,
    0.1828196793794632,
    -0.06769229471683502,
    0.006715534254908562,
    -0.0621270090341568,
    -0.1239347904920578,
    0.4451303482055664,
    0.2674187421798706,
    0.21410731971263885,
    -0.13395659625530243,
    0.12177252024412155,
    0.13320210576057434,
    0.07968433201313019,
    0.07145310938358307,
    0.13488343358039856,
    -0.3376474976539612,
    -0.027925914153456688,
    -0.01877274364233017,
    -0.055770669132471085,
    0.07248318195343018,
    -0.1985192596912384,
    0.41558143496513367,
    -0.21470016241073608,
    0.00180653459392488,
    0.01230315025895834,
    -0.25784197449684143,
    0.16818946599960327,
    -0.13869279623031616,
    0.05139467865228653,
    0.010087383911013603,
    0.21821117401123047,
    -0.096829354763031,
    0.2613685727119446,
    -0.0634373277425766,
    -0.054010000079870224,
    -0.1985006034374237,
    0.03603208810091019,
    0.010746903717517853,
    0.40761250257492065,
    -0.04444914311170578,
    0.018095390871167183,
    -0.15173248946666718,
    0.15368790924549103,
    -0.17171593010425568,
    -0.06542578339576721,
    0.08967467397451401,
    0.023094654083251953,
    -0.11160144954919815,
    0.012936883606016636,
    0.03222038224339485,
    0.16139109432697296,
    -0.0698033794760704,
    0.0025200583040714264,
    -0.13830213248729706,
    -0.19908757507801056,
    -0.04465571790933609,
    -0.3257773518562317,
    -0.24122636020183563,
    0.2163548767566681,
    0.19657863676548004,
    0.24990913271903992,
    0.47722360491752625,
    -0.06870221346616745,
    0.4060593247413635,
    0.01270704809576273,
    0.12326160073280334,
    0.16875870525836945,
    0.10108403116464615,
    -0.06470170617103577,
    0.3649567663669586,
    -0.02642560750246048,
    0.18347720801830292,
    -0.04590265080332756,
    0.10154377669095993,
    -0.23013350367546082,
    0.11789771169424057,
    -0.14196179807186127,
    0.3111759424209595,
    -0.26989394426345825,
    0.10450435429811478,
    -0.08256083726882935,
    -0.09166324138641357,
    -0.43762388825416565,
    -0.03300127387046814,
    0.0586063377559185,
    0.023209918290376663,
    -0.04786481708288193,
    0.1297772228717804,
    0.031117932870984077,
    0.11111341416835785,
    -0.14740192890167236,
    -0.38679540157318115,
    0.02582015097141266,
    -0.05029628798365593,
    -0.2217729240655899,
    0.12298854440450668,
    -0.09051433205604553,
    0.03927312046289444,
    -0.09138064086437225,
    0.015250100754201412,
    0.19535471498966217,
    -0.09158895909786224,
    0.0305732823908329,
    0.22398902475833893,
    -0.059374526143074036,
    -0.2645154595375061,
    -0.35829195380210876,
    -0.06549274921417236,
    0.4043419659137726,
    -0.004617571830749512,
    -0.45933690667152405,
    -0.10569997876882553,
    0.06339605897665024,
    -0.06815588474273682,
    0.16382789611816406,
    0.2128928303718567,
    0.17163580656051636,
    -0.2520802319049835,
    -0.14742188155651093,
    0.03737369552254677,
    -0.6033905744552612,
    0.031192412599921227,
    -0.21649636328220367,
    0.1696641445159912,
    -0.14388948678970337,
    -0.15856055915355682,
    -0.016064852476119995,
    0.42502662539482117,
    0.2341223508119583,
    0.03241221234202385,
    0.11778842657804489,
    0.1338769644498825,
    -0.13620787858963013,
    0.010683199390769005,
    -0.22845351696014404,
    -0.3415237069129944,
    0.22950437664985657,
    -0.26249340176582336,
    -0.08501540869474411,
    -0.08903054147958755,
    0.037564851343631744,
    0.23414592444896698,
    0.34675508737564087,
    0.02467748150229454,
    -0.10153255611658096,
    -0.026179887354373932,
    -0.22871042788028717,
    -0.27654820680618286,
    0.05612671375274658,
    -0.08376747369766235,
    0.1049552634358406,
    -0.013511593453586102,
    0.09128926694393158,
    -0.0011982081923633814,
    0.05062413960695267,
    0.08689695596694946,
    0.23952849209308624,
    0.22834563255310059,
    -0.09084956347942352,
    0.18998661637306213,
    -0.3503563106060028,
    -0.19745531678199768,
    -0.03925514966249466,
    0.403876394033432,
    -0.30546900629997253,
    -0.0010978113859891891,
    0.058379046618938446,
    0.11505014449357986,
    0.07647787034511566,
    0.09666424989700317,
    -0.4285615384578705,
    0.22888298332691193,
    -0.09557950496673584,
    -0.014434341341257095,
    -0.11273092031478882,
    0.2225649058818817,
    0.1214723289012909,
    0.04134359955787659,
    -0.03408576548099518,
    0.3014944791793823,
    -0.06966336816549301,
    -0.015556447207927704,
    -0.1288650631904602,
    0.32450148463249207,
    0.24157102406024933,
    0.22649994492530823,
    0.09195432811975479,
    0.1324455887079239,
    -0.1840941458940506,
    0.037664055824279785,
    -0.0247283224016428,
    0.047795332968235016,
    -0.3711877465248108,
    0.11318389326334,
    0.10009285062551498,
    0.1690656542778015,
    0.0007055314490571618,
    -0.2665793001651764,
    -0.16162775456905365,
    -0.2143493890762329,
    -0.14732767641544342,
    0.03997492045164108,
    -0.08071522414684296,
    0.025499414652585983,
    -0.18366828560829163,
    -0.0026306267827749252,
    0.08807510882616043,
    0.05053887516260147,
    0.22644345462322235,
    -0.2249600887298584,
    0.0743359848856926,
    -0.06598254293203354,
    -0.15972834825515747,
    0.2019716501235962,
    0.007057833950966597,
    0.15507261455059052,
    -0.1137743890285492,
    -0.37573352456092834,
    -0.22254572808742523,
    0.2919546365737915,
    0.10227206349372864,
    -0.0021838638931512833,
    -0.06583461910486221,
    0.02697696164250374,
    -0.16031339764595032,
    0.0013091331347823143,
    -0.38167423009872437,
    0.048076413571834564,
    -0.3681448698043823,
    -0.0686948150396347,
    -0.12983432412147522,
    0.03042253479361534,
    -0.053054507821798325,
    -0.014269194565713406,
    0.027273066341876984,
    -0.08195088058710098,
    0.10262835770845413,
    -0.1975705325603485,
    -0.0011348258703947067,
    -0.008084496483206749,
    0.06330059468746185,
    -0.20593810081481934,
    -0.1521030068397522,
    -0.27547234296798706,
    0.13705690205097198,
    -0.22010597586631775,
    -0.23979435861110687,
    -0.027724653482437134,
    -0.060340628027915955,
    -0.09296640753746033,
    -0.12447866052389145,
    0.1831706464290619,
    0.14675945043563843,
    0.12313313037157059,
    0.007889466360211372,
    -0.14576762914657593,
    -0.16882596909999847,
    0.017858413979411125,
    0.2485218197107315,
    -0.11284790188074112,
    0.3009180426597595,
    -0.16467604041099548,
    -0.29391059279441833,
    0.12656885385513306,
    -0.15594497323036194,
    0.2736760973930359,
    -0.13790778815746307,
    -0.13983769714832306,
    0.26664501428604126,
    0.0009564720094203949,
    -0.3380361795425415,
    0.04647413641214371,
    -0.14481918513774872,
    0.04400748014450073,
    -0.021950390189886093,
    0.11120294034481049,
    0.034938834607601166,
    0.24248531460762024,
    -0.048552513122558594,
    -0.039130110293626785,
    -0.05664297565817833,
    0.293057382106781,
    0.23749183118343353,
    0.061890747398138046,
    0.2265649139881134,
    -0.21199457347393036,
    -0.19780850410461426,
    -0.10714740306138992,
    0.018297407776117325,
    -0.18729877471923828,
    -0.03931368514895439,
    0.07213057577610016,
    -0.45697465538978577,
    -0.019952062517404556,
    -0.2227146327495575,
    0.01789798028767109,
    -0.05090702697634697,
    -0.012803144752979279,
    0.12090910971164703,
    0.27642205357551575,
    0.28505101799964905,
    0.10090625286102295,
    0.14638441801071167,
    -0.2750594913959503,
    0.19013990461826324,
    -0.09395234286785126,
    -0.08940427750349045,
    0.29363691806793213,
    0.02967078983783722,
    0.05469975620508194,
    -0.27136000990867615,
    -0.09450405836105347,
    -0.13537903130054474,
    -0.02756226621568203,
    0.2398587465286255,
    -0.03860166668891907,
    -0.2633676826953888,
    0.1544223576784134,
    0.2102378010749817,
    -0.055723778903484344,
    0.18494635820388794,
    0.02430533431470394,
    -0.0014444207772612572,
    0.01646110974252224,
    -0.2884419858455658,
    0.06975653767585754,
    0.14280545711517334,
    0.21855656802654266,
    -0.054865360260009766,
    -0.2664768397808075,
    0.15404537320137024,
    0.07058555632829666,
    -0.2564086318016052,
    0.025546366348862648,
    -0.18019306659698486,
    0.025199588388204575,
    -0.06954245269298553,
    -0.17014487087726593,
    0.24414581060409546,
    -0.2120237797498703,
    0.08856579661369324,
    -0.07644421607255936,
    0.11976826190948486,
    0.176508828997612,
    0.16417363286018372,
    -0.04531588405370712,
    -0.23630917072296143,
    0.05578522011637688
  ]
}

Thank you so much in advance. :)

    # NOTE(review): the line that opens this call (the collection's aggregate
    # call and the opening "[" of the stage list) is not part of this excerpt;
    # the dicts below are the pipeline stages, closed by "] )" at the end.
    #
    # Stage 1: copy the Python variable `target_embedding` (the query vector)
    # onto every document so later stages can compare against it.
    {
        "$addFields": { 
            "target_embedding": target_embedding
        }
    }
    # Stages 2-3: unwind both arrays, recording each element's position.
    # Unwinding two arrays yields every (embedding element, target element)
    # pair per document, i.e. len(embedding) * len(target_embedding) rows.
    , {"$unwind" : { "path" : "$embedding", "includeArrayIndex": "embedding_index"}}
    , {"$unwind" : { "path" : "$target_embedding", "includeArrayIndex": "target_index" }}
    # Stage 4: "compare" is 0 when both elements come from the same position
    # ($cmp returns 0 for equal values).
    , {
        "$project": {
            "img_path": 1,
            "embedding": 1,
            "name" : 1,
            "target_embedding": 1,
            "compare": {
                "$cmp": ['$embedding_index', '$target_index']
            }
        }
    }
    # Stage 5: keep only the pairs whose indexes match.
    , {"$match": {"compare": 0}}
    # Stage 6: regroup per img_path and sum the squared element differences
    # (the squared Euclidean distance).
    , {
      "$group": {
        "_id": "$img_path",
        "name" : { "$first": "$name" },
        "distance": {
            "$sum": {
                "$pow": [{
                    "$subtract": ['$embedding', '$target_embedding']
                }, 2]
            }
        }
      }
    }
    # Stage 7: take the square root to get the Euclidean distance.
    , { 
        "$project": {
            "name" : 1,
            "_id": 1
            #, "distance": 1
            , "distance": {"$sqrt": "$distance"}
        }
    # Stage 8: flag documents whose distance is within the 4.15 threshold.
    }, { 
        "$project": {
            "name" : 1,
            "_id": 1
            , "distance": 1
            , "cond": { "$lte": [ "$distance", 4.15 ] }
        }
    }
    # Stages 9-10: keep only flagged documents, closest match first.
    , {"$match": {"cond": True}}
    , { "$sort" : { "distance" : 1 } }
    ] )

CodePudding user response:

I'm not familiar with the facial recognition domain so I don't know the significance of the 0.68 threshold, but it's easy to change the final "$match".

The pipeline below:

  1. Fetches a generated "target_embedding" (described below)
  2. Calculates cosine similarity parameters using "$reduce"
    • dot product
    • sum of squares for document embedding
    • sum of squares for target embedding
  3. Calculates cosine similarity as the dot product divided by the square root of the product of the two sums of squares
  4. Thresholds using "$match"

Rather than loading a 512-length array into mongoplayground.net, I used its mgodatagen configuration option to create two collections: "faces" with a hundred documents of embeddings, and "target" with a single document/embedding used as "target_embedding". So, you'll need to change the initial pipeline stages to match your inputs.

// Return the faces whose cosine DISTANCE to the target embedding is <= 0.68,
// closest match first.
//
//   cosine similarity = (a . b) / (|a| * |b|)
//   cosine distance   = 1 - cosine similarity
//
// A small distance means the vectors point in nearly the same direction, so
// the "<= 0.68" threshold is applied to the distance, not to the similarity
// (thresholding the similarity with "$lte" would select the dissimilar faces).
db.faces.aggregate([
  // go get the target embedding
  // ("$lookup" with an empty pipeline attaches every document of "target")
  {
    "$lookup": {
      "from": "target",
      "pipeline": [],
      "as": "target"
    }
  },
  {
    "$set": {
      // first (only) target document's embedding; drop the temporary array
      "target_embedding": {"$first": "$target.embedding"},
      "target": "$$REMOVE"
    }
  },
  // done getting target embedding
  //
  // calculate cosine similarity params in a single pass over the indexes:
  //   dot_product  = sum(doc[i] * target[i])
  //   doc_2_sum    = sum(doc[i]^2)
  //   target_2_sum = sum(target[i]^2)
  {
    "$project": {
      "name": 1,
      "img_path": 1,
      "cos_sim_params": {
        "$reduce": {
          "input": {"$range": [0, {"$size": "$embedding"}]},
          "initialValue": {
            "dot_product": 0,
            "doc_2_sum": 0,
            "target_2_sum": 0
          },
          "in": {
            "$let": {
              "vars": {
                "doc_elem": {"$arrayElemAt": ["$embedding", "$$this"]},
                "target_elem": {"$arrayElemAt": ["$target_embedding", "$$this"]}
              },
              "in": {
                "dot_product": {
                  "$add": [
                    "$$value.dot_product",
                    {"$multiply": ["$$doc_elem", "$$target_elem"]}
                  ]
                },
                "doc_2_sum": {
                  "$add": [
                    "$$value.doc_2_sum",
                    {"$pow": ["$$doc_elem", 2]}
                  ]
                },
                "target_2_sum": {
                  "$add": [
                    "$$value.target_2_sum",
                    {"$pow": ["$$target_elem", 2]}
                  ]
                }
              }
            }
          }
        }
      }
    }
  },
  // cosine similarity = dot product / sqrt(doc_2_sum * target_2_sum)
  {
    "$project": {
      "name": 1,
      "img_path": 1,
      "cos_sim": {
        "$let": {
          "vars": {
            "norm_product": {
              "$sqrt": {
                "$multiply": [
                  "$cos_sim_params.doc_2_sum",
                  "$cos_sim_params.target_2_sum"
                ]
              }
            }
          },
          "in": {
            // "$divide" raises an error on a zero divisor; an all-zero (or
            // empty) vector has no direction, so report null instead
            "$cond": [
              {"$eq": ["$$norm_product", 0]},
              null,
              {"$divide": ["$cos_sim_params.dot_product", "$$norm_product"]}
            ]
          }
        }
      }
    }
  },
  // cosine distance = 1 - cosine similarity (null stays null)
  {
    "$set": {
      "cos_distance": {"$subtract": [1, "$cos_sim"]}
    }
  },
  // keep only close matches; null distances never satisfy a numeric "$lte"
  {
    "$match": {
      "cos_distance": {"$lte": 0.68}
    }
  },
  // best (smallest distance) match first
  {
    "$sort": {"cos_distance": 1}
  }
])

Try it on mongoplayground.net.

  • Related