From 854c69b0602e909d543c16424377579f0d46efac Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Wed, 30 Jan 2019 17:23:41 -0500 Subject: [PATCH 01/13] Changed script to work with Elastic 6.0.0. Inline scripts are now depreciated so using 'source' field instead --- README.md | 4 +- pom.xml | 12 +- .../plugin/VectorScoringPlugin.java | 154 +++++++++++++- .../script/VectorScoreScript.java | 193 ------------------ .../VectorScoringScriptEngineService.java | 78 ------- .../EmbeddedElasticsearchServer.java | 11 +- .../com/liorkn/elasticsearch/PluginTest.java | 2 +- 7 files changed, 159 insertions(+), 295 deletions(-) delete mode 100755 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java delete mode 100755 src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java diff --git a/README.md b/README.md index 9806c21..7088177 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ give it a try. ## Elasticsearch version -* Currently designed for Elasticsearch 5.6.0. +* Currently designed for Elasticsearch 6.0.0. * for Elasticsearch 5.2.2 use branch `es-5.2.2` * for Elasticsearch 2.4.4 use branch `es-2.4.4` @@ -146,7 +146,7 @@ func convertBase64ToArray(base64Str string) ([]float64, error) { "boost_mode": "replace", "script_score": { "script": { - "inline": "binary_vector_score", + "source": "binary_vector_score", "lang": "knn", "params": { "cosine": false, diff --git a/pom.xml b/pom.xml index d765043..bcdee7c 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 5.6.0 + 6.0.0 ElasticSearch Plugin for Binary Vector Scoring @@ -27,7 +27,7 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 5.6.0 + 6.0.0 2.4 4.4.8 4.12 @@ -65,7 +65,7 @@ org.elasticsearch.plugin - transport-netty3-client + transport-netty4-client ${elasticsearch.version} test @@ -86,12 +86,6 @@ - - - - - - org.codelibs.elasticsearch.module lang-painless diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 88a3599..b0a666e 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,13 +13,25 @@ */ package com.liorkn.elasticsearch.plugin; -import com.liorkn.elasticsearch.service.VectorScoringScriptEngineService; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.util.Collection; +import java.util.Map; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; -import org.elasticsearch.script.ScriptEngineService; - +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; +import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -27,9 +39,141 @@ */ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { - public final ScriptEngineService getScriptEngineService(Settings settings) { - return new VectorScoringScriptEngineService(settings); + @Override + public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { + return new MyExpertScriptEngine(); } + /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ + private static class MyExpertScriptEngine implements ScriptEngine { + @Override + public String getType() { + return "knn"; + } + + private static final int DOUBLE_SIZE = 8; + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if ("binary_vector_score".equals(scriptSource)) { + SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { + final String field; + final boolean cosine; + { + if (p.containsKey("vector") == false) { + throw new IllegalArgumentException("Missing parameter [vector]"); + } + if (p.containsKey("field") == false) { + throw new IllegalArgumentException("Missing parameter [field]"); + } + if (p.containsKey("cosine") == false) { + throw new IllegalArgumentException("Missing parameter [cosine]"); + } + field = p.get("field").toString(); + cosine = (boolean) p.get("cosine"); + } + + final ArrayList searchVector = (ArrayList) p.get("vector"); + double magnitude; + { + if (cosine) { + // calc magnitude + double queryVectorNorm = 0.0; + // compute query inputVector norm once + for (Double v : this.searchVector) { + queryVectorNorm += v.doubleValue() * v.doubleValue(); + } + magnitude = Math.sqrt(queryVectorNorm); + } else { + magnitude = 0.0; + } + } + + @Override + public SearchScript newInstance(LeafReaderContext context) throws IOException { + return new SearchScript(p, lookup, context) { + BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); + int currentDocid = -1; + + @Override + public void setDocument(int docid) { + // Move to desired document + try { + docAccess.advanceExact(docid); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + currentDocid = docid; + } + + @Override + public double runAsDouble() { + if (currentDocid < 0) { + return 0.0; + } + //actually run scoring + final int size = searchVector.size(); + + try { + final byte[] bytes = docAccess.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size + if (len != size * DOUBLE_SIZE) { + return 0.0; + } + + final int position = input.getPosition(); + final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); + final double[] docVector = new double[size]; + doubleBuffer.get(docVector); + double docVectorNorm = 0.0f; + double score = 0; + for (int i = 0; i < size; i++) { + // doc inputVector norm + if(cosine) { + docVectorNorm += docVector[i]*docVector[i]; + } + // dot product + score += docVector[i] * searchVector.get(i).doubleValue(); + } + if(cosine) { + // cosine similarity score + if (docVectorNorm == 0 || magnitude == 0){ + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + return score; + } + } catch (Exception e) { + return 0; + } + } + }; + } + + @Override + public boolean needs_score() { + return false; + } + }; + return context.factoryClazz.cast(factory); + } + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + @Override + public void close() { + // optionally close resources + } + } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java deleted file mode 100755 index 0b87cf3..0000000 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ /dev/null @@ -1,193 +0,0 @@ -/* -Based on: https://discuss.elastic.co/t/vector-scoring/85227/4 -and https://github.com/MLnick/elasticsearch-vector-scoring - -another slower implementation using strings: https://github.com/ginobefun/elasticsearch-feature-vector-scoring - -storing arrays is no luck - lucine index doesn't keep the array members orders -https://www.elastic.co/guide/en/elasticsearch/guide/current/complex-core-fields.html - -Delimited Payload Token Filter: https://www.elastic.co/guide/en/elasticsearch/reference/2.4/analysis-delimited-payload-tokenfilter.html - - - */ - -package com.liorkn.elasticsearch.script; - -import com.liorkn.elasticsearch.Util; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.store.ByteArrayDataInput; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptException; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.util.ArrayList; -import java.util.Base64; -import java.util.Map; - -/** - * Script that scores documents based on cosine similarity embedding vectors. - */ -public final class VectorScoreScript implements LeafSearchScript, ExecutableScript { - - public static final String SCRIPT_NAME = "binary_vector_score"; - - private static final int DOUBLE_SIZE = 8; - - // the field containing the vectors to be scored against - public final String field; - - private int docId; - private BinaryDocValues binaryEmbeddingReader; - - private final double[] inputVector; - private final double magnitude; - - private final boolean cosine; - - @Override - public long runAsLong() { - return ((Number)this.run()).longValue(); - } - @Override - public double runAsDouble() { - return ((Number)this.run()).doubleValue(); - } - @Override - public void setNextVar(String name, Object value) {} - @Override - public void setDocument(int docId) { - this.docId = docId; - } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } - - - /** - * Factory that is registered in - * {@link VectorScoringPlugin#onModule(org.elasticsearch.script.ScriptModule)} - * method when the plugin is loaded. - */ - public static class Factory { - - /** - * This method is called for every search on every shard. - * - * @param params - * list of script parameters passed with the query - * @return new native script - */ - public ExecutableScript newScript(@Nullable Map params) throws ScriptException { - return new VectorScoreScript(params); - } - - /** - * Indicates if document scores may be needed by the produced scripts. - * - * @return {@code true} if scores are needed. - */ - public boolean needsScores() { - return false; - } - - } - - - /** - * Init - * @param params index that a scored are placed in this parameter. Initialize them here. - */ - @SuppressWarnings("unchecked") - public VectorScoreScript(Map params) { - final Object cosineBool = params.get("cosine"); - cosine = cosineBool != null ? - (boolean)cosineBool : - true; - - final Object field = params.get("field"); - if (field == null) - throw new IllegalArgumentException("binary_vector_score script requires field input"); - this.field = field.toString(); - - // get query inputVector - convert to primitive - final Object vector = params.get("vector"); - if(vector != null) { - final ArrayList tmp = (ArrayList) vector; - inputVector = new double[tmp.size()]; - for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i); - } - } else { - final Object encodedVector = params.get("encoded_vector"); - if(encodedVector == null) { - throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); - } - inputVector = Util.convertBase64ToArray((String) encodedVector); - } - - if(cosine) { - // calc magnitude - double queryVectorNorm = 0.0; - // compute query inputVector norm once - for (double v : inputVector) { - queryVectorNorm += v * v; - } - magnitude = Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0; - } - } - - - /** - * Called for each document - * @return cosine similarity of the current document against the input inputVector - */ - @Override - public final Object run() { - final int size = inputVector.length; - - final byte[] bytes = binaryEmbeddingReader.get(docId).bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int len = input.readVInt(); // returns the number of bytes to read - if(len != size * DOUBLE_SIZE) { - return 0.0; - } - final int position = input.getPosition(); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); - - final double[] docVector = new double[size]; - doubleBuffer.get(docVector); - - double docVectorNorm = 0.0f; - double score = 0; - for (int i = 0; i < size; i++) { - // doc inputVector norm - if(cosine) { - docVectorNorm += docVector[i]*docVector[i]; - } - // dot product - score += docVector[i] * inputVector[i]; - } - if(cosine) { - // cosine similarity score - if (docVectorNorm == 0 || magnitude == 0){ - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - return score; - } - } - -} \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java b/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java deleted file mode 100755 index 58db087..0000000 --- a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java +++ /dev/null @@ -1,78 +0,0 @@ -package com.liorkn.elasticsearch.service; - -import com.liorkn.elasticsearch.script.VectorScoreScript; -import org.apache.lucene.index.LeafReaderContext; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.component.AbstractComponent; -import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.script.CompiledScript; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptEngineService; -import org.elasticsearch.script.SearchScript; -import org.elasticsearch.search.lookup.SearchLookup; - -import java.io.IOException; -import java.util.Map; - -/** - * Created by Lior Knaany on 5/14/17. - */ -public class VectorScoringScriptEngineService extends AbstractComponent implements ScriptEngineService{ - - public static final String NAME = "knn"; - - @Inject - public VectorScoringScriptEngineService(Settings settings) { - super(settings); - } - - @Override - public Object compile(String scriptName, String scriptSource, Map params) { - return new VectorScoreScript.Factory(); - } - - - @Override - public boolean isInlineScriptEnabled() { - return true; - } - - @Override - public String getType() { - return NAME; - } - - @Override - public String getExtension() { - return NAME; - } - - @Override - public ExecutableScript executable(CompiledScript compiledScript, @Nullable Map vars) { - VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - return scriptFactory.newScript(vars); - } - - @Override - public SearchScript search(CompiledScript compiledScript, final SearchLookup lookup, @Nullable final Map vars) { - final VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - final VectorScoreScript script = (VectorScoreScript) scriptFactory.newScript(vars); - return new SearchScript() { - @Override - public LeafSearchScript getLeafSearchScript(LeafReaderContext context) throws IOException { - script.setBinaryEmbeddingReader(context.reader().getBinaryDocValues(script.field)); - return script; - } - @Override - public boolean needsScores() { - return scriptFactory.needsScores(); - } - }; - } - - @Override - public void close() { - } -} diff --git a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java index 5627240..1cbdb9c 100644 --- a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java +++ b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java @@ -10,7 +10,7 @@ import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; import org.elasticsearch.painless.PainlessPlugin; -import org.elasticsearch.transport.Netty3Plugin; +import org.elasticsearch.transport.Netty4Plugin; import java.io.File; import java.io.IOException; @@ -41,13 +41,10 @@ private EmbeddedElasticsearchServer(String defaultDataDirectory, int port) throw Settings.Builder settings = Settings.builder() .put("http.enabled", "true") - .put("transport.type", "local") - .put("http.type", "netty3") + .put("http.type", "netty4") .put("path.data", dataDirectory) .put("path.home", DEFAULT_HOME_DIRECTORY) - .put("script.inline", "on") - .put("node.max_local_storage_nodes", 10000) - .put("script.stored", "on"); + .put("node.max_local_storage_nodes", 10000); startNodeInAvailablePort(settings); } @@ -61,7 +58,7 @@ private void startNodeInAvailablePort(Settings.Builder settings) throws NodeVali settings.put("http.port", String.valueOf(this.port)); // this a hack in order to load Groovy plug in since we want to enable the usage of scripts - node = new NodeExt(settings.build() , Arrays.asList(Netty3Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); + node = new NodeExt(settings.build() , Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); node.start(); success = true; System.out.println(EmbeddedElasticsearchServer.class.getName() + ": Using port: " + this.port); diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index b95b65c..c134ea6 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -92,7 +92,7 @@ public void test() throws Exception { " \"boost_mode\": \"replace\"," + " \"script_score\": {" + " \"script\": {" + - " \"inline\": \"binary_vector_score\"," + + " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + " \"cosine\": false," + From 7cc89d0eb2ca1f67b076f12eb2db642736ad8f9f Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 31 Jan 2019 13:59:06 -0500 Subject: [PATCH 02/13] renamed engine appropriately --- .../com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index b0a666e..3f97f65 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -41,11 +41,11 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new MyExpertScriptEngine(); + return new VectorScoringPluginEngine(); } /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class MyExpertScriptEngine implements ScriptEngine { + private static class VectorScoringPluginEngine implements ScriptEngine { @Override public String getType() { return "knn"; From f688c3a5a74e63f7dfc4236054854d587e342f03 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:45:15 -0400 Subject: [PATCH 03/13] Broke it back into multiple files for simplicity, added support for encodedVector field, added runAsLong(), changed searchVector to be a primitive type, made SCRIPT_SOURCE private static class member --- .../java/com/liorkn/elasticsearch/Util.java | 35 ++-- .../engine/VectorScoringScriptEngine.java | 36 ++++ .../plugin/VectorScoringPlugin.java | 149 +---------------- .../script/VectorScoreScript.java | 157 ++++++++++++++++++ 4 files changed, 213 insertions(+), 164 deletions(-) create mode 100644 src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java create mode 100644 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java diff --git a/src/main/java/com/liorkn/elasticsearch/Util.java b/src/main/java/com/liorkn/elasticsearch/Util.java index de81af8..ed84f95 100644 --- a/src/main/java/com/liorkn/elasticsearch/Util.java +++ b/src/main/java/com/liorkn/elasticsearch/Util.java @@ -9,23 +9,22 @@ */ public class Util { - public static final double[] convertBase64ToArray(String base64Str) { - final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer(); + public static double[] convertBase64ToArray(String base64Str) { + final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); + final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer(); + final double[] dims = new double[doubleBuffer.capacity()]; + doubleBuffer.get(dims); + return dims; + } - final double[] dims = new double[doubleBuffer.capacity()]; - doubleBuffer.get(dims); - return dims; - } - - public static final String convertArrayToBase64(double[] array) { - final int capacity = 8 * array.length; - final ByteBuffer bb = ByteBuffer.allocate(capacity); - for (int i = 0; i < array.length; i++) { - bb.putDouble(array[i]); - } - bb.rewind(); - final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); - return new String(encodedBB.array()); - } + public static String convertArrayToBase64(double[] array) { + final int capacity = Double.BYTES * array.length; + final ByteBuffer bb = ByteBuffer.allocate(capacity); + for (double v : array) { + bb.putDouble((double) v); + } + bb.rewind(); + final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); + return new String(encodedBB.array()); + } } diff --git a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java new file mode 100644 index 0000000..6c00692 --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java @@ -0,0 +1,36 @@ +package com.liorkn.elasticsearch.engine; + +import com.liorkn.elasticsearch.script.VectorScoreScript; + +import java.util.Map; + +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; + +/** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ +public class VectorScoringScriptEngine implements ScriptEngine { + + public static final String NAME = "knn"; + private static final String SCRIPT_SOURCE = "binary_vector_score"; + + @Override + public String getType() { + return NAME; + } + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if (!SCRIPT_SOURCE.equals(scriptSource)) { + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + SearchScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; + return context.factoryClazz.cast(factory); + } +} diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 3f97f65..c279636 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,25 +13,15 @@ */ package com.liorkn.elasticsearch.plugin; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; +import com.liorkn.elasticsearch.engine.VectorScoringScriptEngine; + import java.util.Collection; -import java.util.Map; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.PostingsEnum; -import org.apache.lucene.index.Term; -import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; -import org.elasticsearch.script.SearchScript; -import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -41,139 +31,6 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new VectorScoringPluginEngine(); - } - - /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class VectorScoringPluginEngine implements ScriptEngine { - @Override - public String getType() { - return "knn"; - } - - private static final int DOUBLE_SIZE = 8; - - @Override - public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { - - if (context.equals(SearchScript.CONTEXT) == false) { - throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); - } - - // we use the script "source" as the script identifier - if ("binary_vector_score".equals(scriptSource)) { - SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { - final String field; - final boolean cosine; - { - if (p.containsKey("vector") == false) { - throw new IllegalArgumentException("Missing parameter [vector]"); - } - if (p.containsKey("field") == false) { - throw new IllegalArgumentException("Missing parameter [field]"); - } - if (p.containsKey("cosine") == false) { - throw new IllegalArgumentException("Missing parameter [cosine]"); - } - field = p.get("field").toString(); - cosine = (boolean) p.get("cosine"); - } - - final ArrayList searchVector = (ArrayList) p.get("vector"); - double magnitude; - { - if (cosine) { - // calc magnitude - double queryVectorNorm = 0.0; - // compute query inputVector norm once - for (Double v : this.searchVector) { - queryVectorNorm += v.doubleValue() * v.doubleValue(); - } - magnitude = Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0; - } - } - - @Override - public SearchScript newInstance(LeafReaderContext context) throws IOException { - return new SearchScript(p, lookup, context) { - BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); - int currentDocid = -1; - - @Override - public void setDocument(int docid) { - // Move to desired document - try { - docAccess.advanceExact(docid); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - currentDocid = docid; - } - - @Override - public double runAsDouble() { - if (currentDocid < 0) { - return 0.0; - } - //actually run scoring - final int size = searchVector.size(); - - try { - final byte[] bytes = docAccess.binaryValue().bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size - if (len != size * DOUBLE_SIZE) { - return 0.0; - } - - final int position = input.getPosition(); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); - - final double[] docVector = new double[size]; - doubleBuffer.get(docVector); - double docVectorNorm = 0.0f; - double score = 0; - for (int i = 0; i < size; i++) { - // doc inputVector norm - if(cosine) { - docVectorNorm += docVector[i]*docVector[i]; - } - // dot product - score += docVector[i] * searchVector.get(i).doubleValue(); - } - if(cosine) { - // cosine similarity score - if (docVectorNorm == 0 || magnitude == 0){ - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - return score; - } - } catch (Exception e) { - return 0; - } - } - }; - } - - @Override - public boolean needs_score() { - return false; - } - }; - return context.factoryClazz.cast(factory); - } - throw new IllegalArgumentException("Unknown script name " + scriptSource); - } - - @Override - public void close() { - // optionally close resources - } + return new VectorScoringScriptEngine(); } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java new file mode 100644 index 0000000..b682acf --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -0,0 +1,157 @@ +package com.liorkn.elasticsearch.script; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.ByteArrayDataInput; +import org.elasticsearch.script.SearchScript; +import org.elasticsearch.search.lookup.SearchLookup; + +import com.liorkn.elasticsearch.Util; + +public final class VectorScoreScript extends SearchScript { + + private BinaryDocValues binaryEmbeddingReader; + + private final String field; + private final boolean cosine; + + private final double[] inputVector; + private final double magnitude; + + @Override + public long runAsLong() { + return (long) runAsDouble(); + } + + @Override + public double runAsDouble() { + try { + final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + + final int len = input.readVInt(); + // in case vector is of different size + if (len != inputVector.length * Double.BYTES) { + return 0.0; + } + + float score = 0; + + if (cosine) { + double docVectorNorm = 0.0f; + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + docVectorNorm += v * v; // inputVector norm + score += v * inputVector[i]; // dot product + } + + if (docVectorNorm == 0 || magnitude == 0) { + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + score += v * inputVector[i]; // dot product + } + + return score; + } + } catch (Exception e) { + return 0.0; + } + } + + @Override + public void setDocument(int docId) { + try { + this.binaryEmbeddingReader.advanceExact(docId); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { + if(binaryEmbeddingReader == null) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + this.binaryEmbeddingReader = binaryEmbeddingReader; + } + + @SuppressWarnings("unchecked") + public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, lookup, leafContext); + + final Object cosineBool = params.get("cosine"); + this.cosine = cosineBool != null ? + (boolean)cosineBool : + true; + + final Object field = params.get("field"); + if (field == null) + throw new IllegalArgumentException("binary_vector_score script requires field input"); + this.field = field.toString(); + + // get query inputVector - convert to primitive + final Object vector = params.get("vector"); + if(vector != null) { + final ArrayList tmp = (ArrayList) vector; + inputVector = new double[tmp.size()]; + for (int i = 0; i < inputVector.length; i++) { + inputVector[i] = tmp.get(i).doubleValue(); + } + } else { + final Object encodedVector = params.get("encoded_vector"); + if(encodedVector == null) { + throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); + } + inputVector = Util.convertBase64ToArray((String) encodedVector); + } + + if (this.cosine) { + // calc magnitude + double queryVectorNorm = 0.0f; + // compute query inputVector norm once + for (double v: this.inputVector) { + queryVectorNorm += v * v; + } + this.magnitude = (double) Math.sqrt(queryVectorNorm); + } else { + this.magnitude = 0.0f; + } + + try { + this.binaryEmbeddingReader = leafContext.reader().getBinaryDocValues(this.field); + } catch (IOException e) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + } + + public static class VectorScoreScriptFactory implements LeafFactory { + private final Map params; + private final SearchLookup lookup; + + public VectorScoreScriptFactory(Map params, SearchLookup lookup) { + this.params = params; + this.lookup = lookup; + } + + public boolean needs_score() { + return false; + } + + @Override + public SearchScript newInstance(LeafReaderContext ctx) throws IOException { + return new VectorScoreScript(this.params, this.lookup, ctx); + } + } +} \ No newline at end of file From 2aa0f87d74b7b0606285b44f9f02ff7848ce1b7c Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:48:28 -0400 Subject: [PATCH 04/13] Removed unused import --- .../java/com/liorkn/elasticsearch/script/VectorScoreScript.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index b682acf..f2bb77c 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -3,7 +3,6 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Map; import org.apache.lucene.index.BinaryDocValues; From e1bae87588d1f7866ec1b945ee1dae3eacfbc3fb Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 13:33:52 -0400 Subject: [PATCH 05/13] Added support for ES 7.x --- pom.xml | 9 ++++---- .../engine/VectorScoringScriptEngine.java | 6 ++--- .../script/VectorScoreScript.java | 22 +++++-------------- 3 files changed, 12 insertions(+), 25 deletions(-) diff --git a/pom.xml b/pom.xml index bcdee7c..1d3ad8e 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 6.0.0 + 7.0.0-rc2 ElasticSearch Plugin for Binary Vector Scoring @@ -27,7 +27,7 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 6.0.0 + 7.0.0-rc2 2.4 4.4.8 4.12 @@ -39,13 +39,13 @@ org.apache.logging.log4j log4j-api - 2.7 + 2.11.1 test org.apache.logging.log4j log4j-core - 2.7 + 2.11.1 test @@ -108,7 +108,6 @@ org.elasticsearch.client elasticsearch-rest-client ${elasticsearch.version} - org.apache.httpcomponents diff --git a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java index 6c00692..0f06f86 100644 --- a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java +++ b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java @@ -6,7 +6,7 @@ import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; -import org.elasticsearch.script.SearchScript; +import org.elasticsearch.script.ScoreScript; /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ public class VectorScoringScriptEngine implements ScriptEngine { @@ -21,7 +21,7 @@ public String getType() { @Override public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { - if (context.equals(SearchScript.CONTEXT) == false) { + if (context.equals(ScoreScript.CONTEXT) == false) { throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); } @@ -30,7 +30,7 @@ public T compile(String scriptName, String scriptSource, ScriptContext co throw new IllegalArgumentException("Unknown script name " + scriptSource); } - SearchScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; + ScoreScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; return context.factoryClazz.cast(factory); } } diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index f2bb77c..04ce38d 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -8,12 +8,12 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.store.ByteArrayDataInput; -import org.elasticsearch.script.SearchScript; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.search.lookup.SearchLookup; import com.liorkn.elasticsearch.Util; -public final class VectorScoreScript extends SearchScript { +public final class VectorScoreScript extends ScoreScript { private BinaryDocValues binaryEmbeddingReader; @@ -22,14 +22,9 @@ public final class VectorScoreScript extends SearchScript { private final double[] inputVector; private final double magnitude; - - @Override - public long runAsLong() { - return (long) runAsDouble(); - } - + @Override - public double runAsDouble() { + public double execute() { try { final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; final ByteArrayDataInput input = new ByteArrayDataInput(bytes); @@ -78,13 +73,6 @@ public void setDocument(int docId) { throw new UncheckedIOException(e); } } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } @SuppressWarnings("unchecked") public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { @@ -149,7 +137,7 @@ public boolean needs_score() { } @Override - public SearchScript newInstance(LeafReaderContext ctx) throws IOException { + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { return new VectorScoreScript(this.params, this.lookup, ctx); } } From a77c72c4c531cb810306f909d1cff316616a6cc3 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Mon, 22 Apr 2019 14:48:37 -0400 Subject: [PATCH 06/13] Added support for ES 7.0.0 --- pom.xml | 4 +- src/main/assemblies/plugin.xml | 2 +- .../script/VectorScoreScript.java | 3 +- .../resources/plugin-descriptor.properties | 1 - .../EmbeddedElasticsearchServer.java | 15 ++++-- .../com/liorkn/elasticsearch/NodeExt.java | 8 ++- .../com/liorkn/elasticsearch/PluginTest.java | 51 +++++++++---------- 7 files changed, 43 insertions(+), 41 deletions(-) diff --git a/pom.xml b/pom.xml index 1d3ad8e..828ffd1 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 7.0.0-rc2 + 7.0.0 ElasticSearch Plugin for Binary Vector Scoring @@ -27,7 +27,7 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 7.0.0-rc2 + 7.0.0 2.4 4.4.8 4.12 diff --git a/src/main/assemblies/plugin.xml b/src/main/assemblies/plugin.xml index fd186c6..b1c09fb 100755 --- a/src/main/assemblies/plugin.xml +++ b/src/main/assemblies/plugin.xml @@ -4,7 +4,7 @@ zip - true + false elasticsearch diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index 04ce38d..9c9ee85 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -23,8 +23,7 @@ public final class VectorScoreScript extends ScoreScript { private final double[] inputVector; private final double magnitude; - @Override - public double execute() { + public double execute() { try { final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; final ByteArrayDataInput input = new ByteArrayDataInput(bytes); diff --git a/src/main/resources/plugin-descriptor.properties b/src/main/resources/plugin-descriptor.properties index eed857f..15eae61 100755 --- a/src/main/resources/plugin-descriptor.properties +++ b/src/main/resources/plugin-descriptor.properties @@ -1,7 +1,6 @@ name=${project.name} description=${project.description} version=${project.version} -jvm=true classname=${elasticsearch.plugin.classname} java.version=1.8 elasticsearch.version=${elasticsearch.version} \ No newline at end of file diff --git a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java index 1cbdb9c..e636a7b 100644 --- a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java +++ b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java @@ -5,8 +5,10 @@ import org.apache.commons.io.FileUtils; import org.elasticsearch.client.Client; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.env.Environment; import org.elasticsearch.http.BindHttpException; import org.elasticsearch.index.reindex.ReindexPlugin; +import org.elasticsearch.node.InternalSettingsPreparer; import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; import org.elasticsearch.painless.PainlessPlugin; @@ -17,6 +19,9 @@ import java.util.Arrays; import java.util.Random; +import static java.util.Collections.emptyMap; + + public class EmbeddedElasticsearchServer { private static final Random RANDOM = new Random(); private static final String DEFAULT_DATA_DIRECTORY = "target/elasticsearch-data"; @@ -40,11 +45,11 @@ private EmbeddedElasticsearchServer(String defaultDataDirectory, int port) throw this.port = port; Settings.Builder settings = Settings.builder() - .put("http.enabled", "true") .put("http.type", "netty4") .put("path.data", dataDirectory) .put("path.home", DEFAULT_HOME_DIRECTORY) - .put("node.max_local_storage_nodes", 10000); + .put("node.max_local_storage_nodes", 10000) + .put("node.name", "test"); startNodeInAvailablePort(settings); } @@ -56,9 +61,11 @@ private void startNodeInAvailablePort(Settings.Builder settings) throws NodeVali while(!success) { try { settings.put("http.port", String.valueOf(this.port)); - + + Settings envSettings = settings.build(); + Environment env = InternalSettingsPreparer.prepareEnvironment(envSettings, emptyMap(), null, null); // this a hack in order to load Groovy plug in since we want to enable the usage of scripts - node = new NodeExt(settings.build() , Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); + node = new NodeExt(env, Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); node.start(); success = true; System.out.println(EmbeddedElasticsearchServer.class.getName() + ": Using port: " + this.port); diff --git a/src/test/java/com/liorkn/elasticsearch/NodeExt.java b/src/test/java/com/liorkn/elasticsearch/NodeExt.java index 5c79d59..b1b72fe 100644 --- a/src/test/java/com/liorkn/elasticsearch/NodeExt.java +++ b/src/test/java/com/liorkn/elasticsearch/NodeExt.java @@ -1,7 +1,6 @@ package com.liorkn.elasticsearch; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.node.InternalSettingsPreparer; +import org.elasticsearch.env.Environment; import org.elasticsearch.node.Node; import org.elasticsearch.plugins.Plugin; @@ -12,9 +11,8 @@ */ public class NodeExt extends Node { - public NodeExt(Settings preparedSettings, Collection> classpathPlugins) { - super(InternalSettingsPreparer.prepareEnvironment(preparedSettings, null), - classpathPlugins); + public NodeExt(Environment env, Collection> classpathPlugins) { + super(env, classpathPlugins, false); } } diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index c134ea6..e859be2 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -1,21 +1,16 @@ package com.liorkn.elasticsearch; import com.fasterxml.jackson.core.JsonParser; + import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.http.HttpHost; -import org.apache.http.entity.ContentType; -import org.apache.http.entity.StringEntity; -import org.apache.http.nio.entity.NStringEntity; -import org.apache.http.util.EntityUtils; -import org.elasticsearch.client.Response; +import org.elasticsearch.client.Request; import org.elasticsearch.client.RestClient; import org.junit.AfterClass; -import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import java.io.IOException; -import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -34,29 +29,30 @@ public static void init() throws Exception { // delete test index if exists try { - esClient.performRequest("DELETE", "/test", Collections.emptyMap()); + Request deleteRequest = new Request("DELETE", "/test"); + esClient.performRequest(deleteRequest); } catch (Exception e) {} // create test index String mappingJson = "{\n" + " \"mappings\": {\n" + - " \"type\": {\n" + - " \"properties\": {\n" + - " \"embedding_vector\": {\n" + - " \"doc_values\": true,\n" + - " \"type\": \"binary\"\n" + - " },\n" + - " \"job_id\": {\n" + - " \"type\": \"long\"\n" + - " },\n" + - " \"vector\": {\n" + - " \"type\": \"float\"\n" + - " }\n" + + " \"properties\": {\n" + + " \"embedding_vector\": {\n" + + " \"doc_values\": true,\n" + + " \"type\": \"binary\"\n" + + " },\n" + + " \"job_id\": {\n" + + " \"type\": \"long\"\n" + + " },\n" + + " \"vector\": {\n" + + " \"type\": \"float\"\n" + " }\n" + " }\n" + " }\n" + "}"; - esClient.performRequest("PUT", "/test", Collections.emptyMap(), new NStringEntity(mappingJson, ContentType.APPLICATION_JSON)); + Request putRequest = new Request("PUT", "/test"); + putRequest.setJsonEntity(mappingJson); + esClient.performRequest(putRequest); } public static final ObjectMapper mapper = new ObjectMapper(); @@ -68,8 +64,6 @@ public static void init() throws Exception { @Test public void test() throws Exception { - final Map params = new HashMap<>(); - params.put("refresh", "true"); final ObjectMapper mapper = new ObjectMapper(); final TestObject[] objs = {new TestObject(1, new double[] {0.0, 0.5, 1.0}), new TestObject(2, new double[] {0.2, 0.6, 0.99})}; @@ -78,7 +72,10 @@ public void test() throws Exception { final TestObject t = objs[i]; final String json = mapper.writeValueAsString(t); System.out.println(json); - final Response put = esClient.performRequest("PUT", "/test/type/" + t.jobId, params, new StringEntity(json, ContentType.APPLICATION_JSON)); + Request indexRequest = new Request("POST", "/test/_doc/" + t.jobId); + indexRequest.addParameter("refresh", "true"); + indexRequest.setJsonEntity(json); + final Response put = esClient.performRequest(indexRequest); System.out.println(put); System.out.println(EntityUtils.toString(put.getEntity())); final int statusCode = put.getStatusLine().getStatusCode(); @@ -107,12 +104,14 @@ public void test() throws Exception { " }," + " \"size\": 100" + "}"; - final Response res = esClient.performRequest("POST", "/test/_search", Collections.emptyMap(), new NStringEntity(body, ContentType.APPLICATION_JSON)); + Request searchRequest = new Request("POST", "/test/_search"); + searchRequest.setJsonEntity(body); + final Response res = esClient.performRequest(searchRequest); System.out.println(res); final String resBody = EntityUtils.toString(res.getEntity()); System.out.println(resBody); Assert.assertEquals("search should return status code 200", 200, res.getStatusLine().getStatusCode()); - Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":" + objs.length)); + Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":{\"value\":" + objs.length)); } @AfterClass From 63d84a3ebe89ba012358b06a1db85bce8e549273 Mon Sep 17 00:00:00 2001 From: Lior Knaany Date: Tue, 23 Apr 2019 21:36:12 +0300 Subject: [PATCH 07/13] changed a double usage in a test to float. as part of the move to floats (for performance reasons), this was missed --- src/main/java/com/liorkn/elasticsearch/Util.java | 2 +- src/test/java/com/liorkn/elasticsearch/PluginTest.java | 4 ++-- src/test/java/com/liorkn/elasticsearch/TestObject.java | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/Util.java b/src/main/java/com/liorkn/elasticsearch/Util.java index 045a0d0..53cf1f6 100644 --- a/src/main/java/com/liorkn/elasticsearch/Util.java +++ b/src/main/java/com/liorkn/elasticsearch/Util.java @@ -18,7 +18,7 @@ public static float[] convertBase64ToArray(String base64Str) { return dims; } - public static String convertArrayToBase64(double[] array) { + public static String convertArrayToBase64(float[] array) { final int capacity = Float.BYTES * array.length; final ByteBuffer bb = ByteBuffer.allocate(capacity); for (double v : array) { diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index 417db02..3a33746 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -72,8 +72,8 @@ public static void init() throws Exception { public void test() throws Exception { final Map params = new HashMap<>(); params.put("refresh", "true"); - final TestObject[] objs = {new TestObject(1, new double[] {0.0, 0.5, 1.0}), - new TestObject(2, new double[] {0.2, 0.6, 0.99})}; + final TestObject[] objs = {new TestObject(1, new float[] {0.0f, 0.5f, 1.0f}), + new TestObject(2, new float[] {0.2f, 0.6f, 0.99f})}; for (int i = 0; i < objs.length; i++) { final TestObject t = objs[i]; diff --git a/src/test/java/com/liorkn/elasticsearch/TestObject.java b/src/test/java/com/liorkn/elasticsearch/TestObject.java index f37d98a..a8d0391 100644 --- a/src/test/java/com/liorkn/elasticsearch/TestObject.java +++ b/src/test/java/com/liorkn/elasticsearch/TestObject.java @@ -10,7 +10,7 @@ public class TestObject { int jobId; String embeddingVector; - double[] vector; + float[] vector; public int getJobId() { return jobId; @@ -20,11 +20,11 @@ public String getEmbeddingVector() { return embeddingVector; } - public double[] getVector() { + public float[] getVector() { return vector; } - public TestObject(int jobId, double[] vector) { + public TestObject(int jobId, float[] vector) { this.jobId = jobId; this.vector = vector; this.embeddingVector = Util.convertArrayToBase64(vector); From 063240e10577541a272f6fb625909a7a01e253ed Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Wed, 30 Jan 2019 17:23:41 -0500 Subject: [PATCH 08/13] Changed script to work with Elastic 6.0.0. Inline scripts are now depreciated so using 'source' field instead --- README.md | 4 +- pom.xml | 12 +- .../plugin/VectorScoringPlugin.java | 154 ++++++++++++++- .../script/VectorScoreScript.java | 179 ------------------ .../VectorScoringScriptEngineService.java | 77 -------- .../EmbeddedElasticsearchServer.java | 11 +- .../com/liorkn/elasticsearch/PluginTest.java | 2 +- 7 files changed, 159 insertions(+), 280 deletions(-) delete mode 100755 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java delete mode 100755 src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java diff --git a/README.md b/README.md index 9ad9f9b..5a8f048 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ give it a try. ## Elasticsearch version -* Currently designed for Elasticsearch 5.6.0. +* Currently designed for Elasticsearch 6.0.0. * for Elasticsearch 5.2.2 use branch `es-5.2.2` * for Elasticsearch 2.4.4 use branch `es-2.4.4` @@ -147,7 +147,7 @@ func convertBase64ToArray(base64Str string) ([]float32, error) { "boost_mode": "replace", "script_score": { "script": { - "inline": "binary_vector_score", + "source": "binary_vector_score", "lang": "knn", "params": { "cosine": false, diff --git a/pom.xml b/pom.xml index 7434047..79d7430 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 5.6.0 + 6.0.0 ElasticSearch Plugin for Binary Vector Scoring @@ -27,7 +27,7 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 5.6.0 + 6.0.0 2.4 4.4.8 4.12 @@ -65,7 +65,7 @@ org.elasticsearch.plugin - transport-netty3-client + transport-netty4-client ${elasticsearch.version} test @@ -86,12 +86,6 @@ - - - - - - org.codelibs.elasticsearch.module lang-painless diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 88a3599..b0a666e 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,13 +13,25 @@ */ package com.liorkn.elasticsearch.plugin; -import com.liorkn.elasticsearch.service.VectorScoringScriptEngineService; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.util.Collection; +import java.util.Map; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; -import org.elasticsearch.script.ScriptEngineService; - +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; +import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -27,9 +39,141 @@ */ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { - public final ScriptEngineService getScriptEngineService(Settings settings) { - return new VectorScoringScriptEngineService(settings); + @Override + public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { + return new MyExpertScriptEngine(); } + /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ + private static class MyExpertScriptEngine implements ScriptEngine { + @Override + public String getType() { + return "knn"; + } + + private static final int DOUBLE_SIZE = 8; + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if ("binary_vector_score".equals(scriptSource)) { + SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { + final String field; + final boolean cosine; + { + if (p.containsKey("vector") == false) { + throw new IllegalArgumentException("Missing parameter [vector]"); + } + if (p.containsKey("field") == false) { + throw new IllegalArgumentException("Missing parameter [field]"); + } + if (p.containsKey("cosine") == false) { + throw new IllegalArgumentException("Missing parameter [cosine]"); + } + field = p.get("field").toString(); + cosine = (boolean) p.get("cosine"); + } + + final ArrayList searchVector = (ArrayList) p.get("vector"); + double magnitude; + { + if (cosine) { + // calc magnitude + double queryVectorNorm = 0.0; + // compute query inputVector norm once + for (Double v : this.searchVector) { + queryVectorNorm += v.doubleValue() * v.doubleValue(); + } + magnitude = Math.sqrt(queryVectorNorm); + } else { + magnitude = 0.0; + } + } + + @Override + public SearchScript newInstance(LeafReaderContext context) throws IOException { + return new SearchScript(p, lookup, context) { + BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); + int currentDocid = -1; + + @Override + public void setDocument(int docid) { + // Move to desired document + try { + docAccess.advanceExact(docid); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + currentDocid = docid; + } + + @Override + public double runAsDouble() { + if (currentDocid < 0) { + return 0.0; + } + //actually run scoring + final int size = searchVector.size(); + + try { + final byte[] bytes = docAccess.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size + if (len != size * DOUBLE_SIZE) { + return 0.0; + } + + final int position = input.getPosition(); + final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); + final double[] docVector = new double[size]; + doubleBuffer.get(docVector); + double docVectorNorm = 0.0f; + double score = 0; + for (int i = 0; i < size; i++) { + // doc inputVector norm + if(cosine) { + docVectorNorm += docVector[i]*docVector[i]; + } + // dot product + score += docVector[i] * searchVector.get(i).doubleValue(); + } + if(cosine) { + // cosine similarity score + if (docVectorNorm == 0 || magnitude == 0){ + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + return score; + } + } catch (Exception e) { + return 0; + } + } + }; + } + + @Override + public boolean needs_score() { + return false; + } + }; + return context.factoryClazz.cast(factory); + } + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + @Override + public void close() { + // optionally close resources + } + } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java deleted file mode 100755 index c9e0401..0000000 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ /dev/null @@ -1,179 +0,0 @@ -/* -Based on: https://discuss.elastic.co/t/vector-scoring/85227/4 -and https://github.com/MLnick/elasticsearch-vector-scoring - -another slower implementation using strings: https://github.com/ginobefun/elasticsearch-feature-vector-scoring - -storing arrays is no luck - lucine index doesn't keep the array members orders -https://www.elastic.co/guide/en/elasticsearch/guide/current/complex-core-fields.html - -Delimited Payload Token Filter: https://www.elastic.co/guide/en/elasticsearch/reference/2.4/analysis-delimited-payload-tokenfilter.html - */ - -package com.liorkn.elasticsearch.script; - -import com.liorkn.elasticsearch.Util; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.store.ByteArrayDataInput; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptException; - -import java.util.ArrayList; -import java.util.Map; - - -/** - * Script that scores documents based on cosine similarity embedding vectors. - */ -public final class VectorScoreScript implements LeafSearchScript, ExecutableScript { - - // the field containing the vectors to be scored against - public final String field; - - private int docId; - private BinaryDocValues binaryEmbeddingReader; - - private final float[] inputVector; - private final float magnitude; - - private final boolean cosine; - - @Override - public final Object run() { - return runAsDouble(); - } - - @Override - public long runAsLong() { - return (long) runAsDouble(); - } - - /** - * Called for each document - * @return cosine similarity of the current document against the input inputVector - */ - @Override - public double runAsDouble() { - final byte[] bytes = binaryEmbeddingReader.get(docId).bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - - // MUST appear hear since it affect the next calls - input.readVInt(); // returns the number of values which should be 1 - input.readVInt(); // returns the number of bytes to read - - float score = 0; - - if(cosine) { - float docVectorNorm = 0.0f; - - for (int i = 0; i < inputVector.length; i++) { - float v = Float.intBitsToFloat(input.readInt()); - docVectorNorm += v * v; // inputVector norm - score += v * inputVector[i]; // dot product - } - - if (docVectorNorm == 0 || magnitude == 0) { - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - for (int i = 0; i < inputVector.length; i++) { - float v = Float.intBitsToFloat(input.readInt()); - score += v * inputVector[i]; // dot product - } - - return score; - } - } - - @Override - public void setNextVar(String name, Object value) {} - - @Override - public void setDocument(int docId) { - this.docId = docId; - } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } - - /** - * Factory that is registered in - * {@link VectorScoringPlugin#onModule(org.elasticsearch.script.ScriptModule)} - * method when the plugin is loaded. - */ - public static class Factory { - - /** - * This method is called for every search on every shard. - * - * @param params - * list of script parameters passed with the query - * @return new native script - */ - public ExecutableScript newScript(@Nullable Map params) throws ScriptException { - return new VectorScoreScript(params); - } - - /** - * Indicates if document scores may be needed by the produced scripts. - * - * @return {@code true} if scores are needed. - */ - public boolean needsScores() { - return false; - } - } - - /** - * Init - * @param params index that a scored are placed in this parameter. Initialize them here. - */ - @SuppressWarnings("unchecked") - public VectorScoreScript(Map params) { - final Object cosineBool = params.get("cosine"); - cosine = cosineBool != null ? - (boolean)cosineBool : - true; - - final Object field = params.get("field"); - if (field == null) - throw new IllegalArgumentException("binary_vector_score script requires field input"); - this.field = field.toString(); - - // get query inputVector - convert to primitive - final Object vector = params.get("vector"); - if(vector != null) { - final ArrayList tmp = (ArrayList) vector; - inputVector = new float[tmp.size()]; - for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i).floatValue(); - } - } else { - final Object encodedVector = params.get("encoded_vector"); - if(encodedVector == null) { - throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); - } - inputVector = Util.convertBase64ToArray((String) encodedVector); - } - - if(cosine) { - // calc magnitude - float queryVectorNorm = 0.0f; - // compute query inputVector norm once - for (float v: inputVector) { - queryVectorNorm += v * v; - } - magnitude = (float) Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0f; - } - } -} \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java b/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java deleted file mode 100755 index 269df8a..0000000 --- a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java +++ /dev/null @@ -1,77 +0,0 @@ -package com.liorkn.elasticsearch.service; - -import com.liorkn.elasticsearch.script.VectorScoreScript; -import org.apache.lucene.index.LeafReaderContext; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.component.AbstractComponent; -import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.script.CompiledScript; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptEngineService; -import org.elasticsearch.script.SearchScript; -import org.elasticsearch.search.lookup.SearchLookup; - -import java.io.IOException; -import java.util.Map; - -/** - * Created by Lior Knaany on 5/14/17. - */ -public class VectorScoringScriptEngineService extends AbstractComponent implements ScriptEngineService{ - - public static final String NAME = "knn"; - - @Inject - public VectorScoringScriptEngineService(Settings settings) { - super(settings); - } - - @Override - public Object compile(String scriptName, String scriptSource, Map params) { - return new VectorScoreScript.Factory(); - } - - @Override - public boolean isInlineScriptEnabled() { - return true; - } - - @Override - public String getType() { - return NAME; - } - - @Override - public String getExtension() { - return NAME; - } - - @Override - public ExecutableScript executable(CompiledScript compiledScript, @Nullable Map vars) { - VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - return scriptFactory.newScript(vars); - } - - @Override - public SearchScript search(CompiledScript compiledScript, final SearchLookup lookup, @Nullable final Map vars) { - final VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - final VectorScoreScript script = (VectorScoreScript) scriptFactory.newScript(vars); - return new SearchScript() { - @Override - public LeafSearchScript getLeafSearchScript(LeafReaderContext context) throws IOException { - script.setBinaryEmbeddingReader(context.reader().getBinaryDocValues(script.field)); - return script; - } - @Override - public boolean needsScores() { - return scriptFactory.needsScores(); - } - }; - } - - @Override - public void close() { - } -} diff --git a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java index 5627240..1cbdb9c 100644 --- a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java +++ b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java @@ -10,7 +10,7 @@ import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; import org.elasticsearch.painless.PainlessPlugin; -import org.elasticsearch.transport.Netty3Plugin; +import org.elasticsearch.transport.Netty4Plugin; import java.io.File; import java.io.IOException; @@ -41,13 +41,10 @@ private EmbeddedElasticsearchServer(String defaultDataDirectory, int port) throw Settings.Builder settings = Settings.builder() .put("http.enabled", "true") - .put("transport.type", "local") - .put("http.type", "netty3") + .put("http.type", "netty4") .put("path.data", dataDirectory) .put("path.home", DEFAULT_HOME_DIRECTORY) - .put("script.inline", "on") - .put("node.max_local_storage_nodes", 10000) - .put("script.stored", "on"); + .put("node.max_local_storage_nodes", 10000); startNodeInAvailablePort(settings); } @@ -61,7 +58,7 @@ private void startNodeInAvailablePort(Settings.Builder settings) throws NodeVali settings.put("http.port", String.valueOf(this.port)); // this a hack in order to load Groovy plug in since we want to enable the usage of scripts - node = new NodeExt(settings.build() , Arrays.asList(Netty3Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); + node = new NodeExt(settings.build() , Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); node.start(); success = true; System.out.println(EmbeddedElasticsearchServer.class.getName() + ": Using port: " + this.port); diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index 3a33746..b7d5747 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -93,7 +93,7 @@ public void test() throws Exception { " \"boost_mode\": \"replace\"," + " \"script_score\": {" + " \"script\": {" + - " \"inline\": \"binary_vector_score\"," + + " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + " \"cosine\": true," + From 1bccb562b00a3182feb501bb179172c15e0ea479 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 31 Jan 2019 13:59:06 -0500 Subject: [PATCH 09/13] renamed engine appropriately --- .../com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index b0a666e..3f97f65 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -41,11 +41,11 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new MyExpertScriptEngine(); + return new VectorScoringPluginEngine(); } /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class MyExpertScriptEngine implements ScriptEngine { + private static class VectorScoringPluginEngine implements ScriptEngine { @Override public String getType() { return "knn"; From 9763f2c13d1581479c75916e8a28f50c42860b06 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:45:15 -0400 Subject: [PATCH 10/13] Broke it back into multiple files for simplicity, added support for encodedVector field, added runAsLong(), changed searchVector to be a primitive type, made SCRIPT_SOURCE private static class member --- .../engine/VectorScoringScriptEngine.java | 36 ++++ .../plugin/VectorScoringPlugin.java | 149 +---------------- .../script/VectorScoreScript.java | 157 ++++++++++++++++++ 3 files changed, 196 insertions(+), 146 deletions(-) create mode 100644 src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java create mode 100644 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java diff --git a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java new file mode 100644 index 0000000..6c00692 --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java @@ -0,0 +1,36 @@ +package com.liorkn.elasticsearch.engine; + +import com.liorkn.elasticsearch.script.VectorScoreScript; + +import java.util.Map; + +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; + +/** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ +public class VectorScoringScriptEngine implements ScriptEngine { + + public static final String NAME = "knn"; + private static final String SCRIPT_SOURCE = "binary_vector_score"; + + @Override + public String getType() { + return NAME; + } + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if (!SCRIPT_SOURCE.equals(scriptSource)) { + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + SearchScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; + return context.factoryClazz.cast(factory); + } +} diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 3f97f65..c279636 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,25 +13,15 @@ */ package com.liorkn.elasticsearch.plugin; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; +import com.liorkn.elasticsearch.engine.VectorScoringScriptEngine; + import java.util.Collection; -import java.util.Map; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.PostingsEnum; -import org.apache.lucene.index.Term; -import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; -import org.elasticsearch.script.SearchScript; -import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -41,139 +31,6 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new VectorScoringPluginEngine(); - } - - /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class VectorScoringPluginEngine implements ScriptEngine { - @Override - public String getType() { - return "knn"; - } - - private static final int DOUBLE_SIZE = 8; - - @Override - public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { - - if (context.equals(SearchScript.CONTEXT) == false) { - throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); - } - - // we use the script "source" as the script identifier - if ("binary_vector_score".equals(scriptSource)) { - SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { - final String field; - final boolean cosine; - { - if (p.containsKey("vector") == false) { - throw new IllegalArgumentException("Missing parameter [vector]"); - } - if (p.containsKey("field") == false) { - throw new IllegalArgumentException("Missing parameter [field]"); - } - if (p.containsKey("cosine") == false) { - throw new IllegalArgumentException("Missing parameter [cosine]"); - } - field = p.get("field").toString(); - cosine = (boolean) p.get("cosine"); - } - - final ArrayList searchVector = (ArrayList) p.get("vector"); - double magnitude; - { - if (cosine) { - // calc magnitude - double queryVectorNorm = 0.0; - // compute query inputVector norm once - for (Double v : this.searchVector) { - queryVectorNorm += v.doubleValue() * v.doubleValue(); - } - magnitude = Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0; - } - } - - @Override - public SearchScript newInstance(LeafReaderContext context) throws IOException { - return new SearchScript(p, lookup, context) { - BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); - int currentDocid = -1; - - @Override - public void setDocument(int docid) { - // Move to desired document - try { - docAccess.advanceExact(docid); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - currentDocid = docid; - } - - @Override - public double runAsDouble() { - if (currentDocid < 0) { - return 0.0; - } - //actually run scoring - final int size = searchVector.size(); - - try { - final byte[] bytes = docAccess.binaryValue().bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size - if (len != size * DOUBLE_SIZE) { - return 0.0; - } - - final int position = input.getPosition(); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); - - final double[] docVector = new double[size]; - doubleBuffer.get(docVector); - double docVectorNorm = 0.0f; - double score = 0; - for (int i = 0; i < size; i++) { - // doc inputVector norm - if(cosine) { - docVectorNorm += docVector[i]*docVector[i]; - } - // dot product - score += docVector[i] * searchVector.get(i).doubleValue(); - } - if(cosine) { - // cosine similarity score - if (docVectorNorm == 0 || magnitude == 0){ - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - return score; - } - } catch (Exception e) { - return 0; - } - } - }; - } - - @Override - public boolean needs_score() { - return false; - } - }; - return context.factoryClazz.cast(factory); - } - throw new IllegalArgumentException("Unknown script name " + scriptSource); - } - - @Override - public void close() { - // optionally close resources - } + return new VectorScoringScriptEngine(); } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java new file mode 100644 index 0000000..b682acf --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -0,0 +1,157 @@ +package com.liorkn.elasticsearch.script; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.ByteArrayDataInput; +import org.elasticsearch.script.SearchScript; +import org.elasticsearch.search.lookup.SearchLookup; + +import com.liorkn.elasticsearch.Util; + +public final class VectorScoreScript extends SearchScript { + + private BinaryDocValues binaryEmbeddingReader; + + private final String field; + private final boolean cosine; + + private final double[] inputVector; + private final double magnitude; + + @Override + public long runAsLong() { + return (long) runAsDouble(); + } + + @Override + public double runAsDouble() { + try { + final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + + final int len = input.readVInt(); + // in case vector is of different size + if (len != inputVector.length * Double.BYTES) { + return 0.0; + } + + float score = 0; + + if (cosine) { + double docVectorNorm = 0.0f; + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + docVectorNorm += v * v; // inputVector norm + score += v * inputVector[i]; // dot product + } + + if (docVectorNorm == 0 || magnitude == 0) { + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + score += v * inputVector[i]; // dot product + } + + return score; + } + } catch (Exception e) { + return 0.0; + } + } + + @Override + public void setDocument(int docId) { + try { + this.binaryEmbeddingReader.advanceExact(docId); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { + if(binaryEmbeddingReader == null) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + this.binaryEmbeddingReader = binaryEmbeddingReader; + } + + @SuppressWarnings("unchecked") + public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, lookup, leafContext); + + final Object cosineBool = params.get("cosine"); + this.cosine = cosineBool != null ? + (boolean)cosineBool : + true; + + final Object field = params.get("field"); + if (field == null) + throw new IllegalArgumentException("binary_vector_score script requires field input"); + this.field = field.toString(); + + // get query inputVector - convert to primitive + final Object vector = params.get("vector"); + if(vector != null) { + final ArrayList tmp = (ArrayList) vector; + inputVector = new double[tmp.size()]; + for (int i = 0; i < inputVector.length; i++) { + inputVector[i] = tmp.get(i).doubleValue(); + } + } else { + final Object encodedVector = params.get("encoded_vector"); + if(encodedVector == null) { + throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); + } + inputVector = Util.convertBase64ToArray((String) encodedVector); + } + + if (this.cosine) { + // calc magnitude + double queryVectorNorm = 0.0f; + // compute query inputVector norm once + for (double v: this.inputVector) { + queryVectorNorm += v * v; + } + this.magnitude = (double) Math.sqrt(queryVectorNorm); + } else { + this.magnitude = 0.0f; + } + + try { + this.binaryEmbeddingReader = leafContext.reader().getBinaryDocValues(this.field); + } catch (IOException e) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + } + + public static class VectorScoreScriptFactory implements LeafFactory { + private final Map params; + private final SearchLookup lookup; + + public VectorScoreScriptFactory(Map params, SearchLookup lookup) { + this.params = params; + this.lookup = lookup; + } + + public boolean needs_score() { + return false; + } + + @Override + public SearchScript newInstance(LeafReaderContext ctx) throws IOException { + return new VectorScoreScript(this.params, this.lookup, ctx); + } + } +} \ No newline at end of file From d0392d71c2603c6496111b3d99b350388ed008d5 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:48:28 -0400 Subject: [PATCH 11/13] Removed unused import --- .../java/com/liorkn/elasticsearch/script/VectorScoreScript.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index b682acf..f2bb77c 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -3,7 +3,6 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Map; import org.apache.lucene.index.BinaryDocValues; From e4a821e59fe9b7456fb63140cf2e9eed34e1290e Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 13:33:52 -0400 Subject: [PATCH 12/13] Added support for ES 7.x --- pom.xml | 9 ++++---- .../engine/VectorScoringScriptEngine.java | 6 ++--- .../script/VectorScoreScript.java | 22 +++++-------------- 3 files changed, 12 insertions(+), 25 deletions(-) diff --git a/pom.xml b/pom.xml index 79d7430..2eeba19 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 6.0.0 + 7.0.0-rc2 ElasticSearch Plugin for Binary Vector Scoring @@ -27,7 +27,7 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 6.0.0 + 7.0.0-rc2 2.4 4.4.8 4.12 @@ -39,13 +39,13 @@ org.apache.logging.log4j log4j-api - 2.7 + 2.11.1 test org.apache.logging.log4j log4j-core - 2.7 + 2.11.1 test @@ -108,7 +108,6 @@ org.elasticsearch.client elasticsearch-rest-client ${elasticsearch.version} - org.apache.httpcomponents diff --git a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java index 6c00692..0f06f86 100644 --- a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java +++ b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java @@ -6,7 +6,7 @@ import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; -import org.elasticsearch.script.SearchScript; +import org.elasticsearch.script.ScoreScript; /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ public class VectorScoringScriptEngine implements ScriptEngine { @@ -21,7 +21,7 @@ public String getType() { @Override public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { - if (context.equals(SearchScript.CONTEXT) == false) { + if (context.equals(ScoreScript.CONTEXT) == false) { throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); } @@ -30,7 +30,7 @@ public T compile(String scriptName, String scriptSource, ScriptContext co throw new IllegalArgumentException("Unknown script name " + scriptSource); } - SearchScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; + ScoreScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; return context.factoryClazz.cast(factory); } } diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index f2bb77c..04ce38d 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -8,12 +8,12 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.store.ByteArrayDataInput; -import org.elasticsearch.script.SearchScript; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.search.lookup.SearchLookup; import com.liorkn.elasticsearch.Util; -public final class VectorScoreScript extends SearchScript { +public final class VectorScoreScript extends ScoreScript { private BinaryDocValues binaryEmbeddingReader; @@ -22,14 +22,9 @@ public final class VectorScoreScript extends SearchScript { private final double[] inputVector; private final double magnitude; - - @Override - public long runAsLong() { - return (long) runAsDouble(); - } - + @Override - public double runAsDouble() { + public double execute() { try { final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; final ByteArrayDataInput input = new ByteArrayDataInput(bytes); @@ -78,13 +73,6 @@ public void setDocument(int docId) { throw new UncheckedIOException(e); } } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } @SuppressWarnings("unchecked") public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { @@ -149,7 +137,7 @@ public boolean needs_score() { } @Override - public SearchScript newInstance(LeafReaderContext ctx) throws IOException { + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { return new VectorScoreScript(this.params, this.lookup, ctx); } } From f9f4f38baf141541799db9b1eb298aeeb6a60c0c Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Mon, 22 Apr 2019 14:48:37 -0400 Subject: [PATCH 13/13] Added support for ES 7.0.0 --- pom.xml | 4 +- src/main/assemblies/plugin.xml | 2 +- .../script/VectorScoreScript.java | 3 +- .../resources/plugin-descriptor.properties | 1 - .../EmbeddedElasticsearchServer.java | 15 +++-- .../com/liorkn/elasticsearch/NodeExt.java | 8 +-- .../com/liorkn/elasticsearch/PluginTest.java | 61 +++++++++---------- 7 files changed, 46 insertions(+), 48 deletions(-) diff --git a/pom.xml b/pom.xml index 2eeba19..b183f06 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 7.0.0-rc2 + 7.0.0 ElasticSearch Plugin for Binary Vector Scoring @@ -27,7 +27,7 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 7.0.0-rc2 + 7.0.0 2.4 4.4.8 4.12 diff --git a/src/main/assemblies/plugin.xml b/src/main/assemblies/plugin.xml index fd186c6..b1c09fb 100755 --- a/src/main/assemblies/plugin.xml +++ b/src/main/assemblies/plugin.xml @@ -4,7 +4,7 @@ zip - true + false elasticsearch diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index 04ce38d..9c9ee85 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -23,8 +23,7 @@ public final class VectorScoreScript extends ScoreScript { private final double[] inputVector; private final double magnitude; - @Override - public double execute() { + public double execute() { try { final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; final ByteArrayDataInput input = new ByteArrayDataInput(bytes); diff --git a/src/main/resources/plugin-descriptor.properties b/src/main/resources/plugin-descriptor.properties index eed857f..15eae61 100755 --- a/src/main/resources/plugin-descriptor.properties +++ b/src/main/resources/plugin-descriptor.properties @@ -1,7 +1,6 @@ name=${project.name} description=${project.description} version=${project.version} -jvm=true classname=${elasticsearch.plugin.classname} java.version=1.8 elasticsearch.version=${elasticsearch.version} \ No newline at end of file diff --git a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java index 1cbdb9c..e636a7b 100644 --- a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java +++ b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java @@ -5,8 +5,10 @@ import org.apache.commons.io.FileUtils; import org.elasticsearch.client.Client; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.env.Environment; import org.elasticsearch.http.BindHttpException; import org.elasticsearch.index.reindex.ReindexPlugin; +import org.elasticsearch.node.InternalSettingsPreparer; import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; import org.elasticsearch.painless.PainlessPlugin; @@ -17,6 +19,9 @@ import java.util.Arrays; import java.util.Random; +import static java.util.Collections.emptyMap; + + public class EmbeddedElasticsearchServer { private static final Random RANDOM = new Random(); private static final String DEFAULT_DATA_DIRECTORY = "target/elasticsearch-data"; @@ -40,11 +45,11 @@ private EmbeddedElasticsearchServer(String defaultDataDirectory, int port) throw this.port = port; Settings.Builder settings = Settings.builder() - .put("http.enabled", "true") .put("http.type", "netty4") .put("path.data", dataDirectory) .put("path.home", DEFAULT_HOME_DIRECTORY) - .put("node.max_local_storage_nodes", 10000); + .put("node.max_local_storage_nodes", 10000) + .put("node.name", "test"); startNodeInAvailablePort(settings); } @@ -56,9 +61,11 @@ private void startNodeInAvailablePort(Settings.Builder settings) throws NodeVali while(!success) { try { settings.put("http.port", String.valueOf(this.port)); - + + Settings envSettings = settings.build(); + Environment env = InternalSettingsPreparer.prepareEnvironment(envSettings, emptyMap(), null, null); // this a hack in order to load Groovy plug in since we want to enable the usage of scripts - node = new NodeExt(settings.build() , Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); + node = new NodeExt(env, Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); node.start(); success = true; System.out.println(EmbeddedElasticsearchServer.class.getName() + ": Using port: " + this.port); diff --git a/src/test/java/com/liorkn/elasticsearch/NodeExt.java b/src/test/java/com/liorkn/elasticsearch/NodeExt.java index 5c79d59..b1b72fe 100644 --- a/src/test/java/com/liorkn/elasticsearch/NodeExt.java +++ b/src/test/java/com/liorkn/elasticsearch/NodeExt.java @@ -1,7 +1,6 @@ package com.liorkn.elasticsearch; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.node.InternalSettingsPreparer; +import org.elasticsearch.env.Environment; import org.elasticsearch.node.Node; import org.elasticsearch.plugins.Plugin; @@ -12,9 +11,8 @@ */ public class NodeExt extends Node { - public NodeExt(Settings preparedSettings, Collection> classpathPlugins) { - super(InternalSettingsPreparer.prepareEnvironment(preparedSettings, null), - classpathPlugins); + public NodeExt(Environment env, Collection> classpathPlugins) { + super(env, classpathPlugins, false); } } diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index b7d5747..d4c41c8 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -1,25 +1,19 @@ package com.liorkn.elasticsearch; import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.JsonNode; + import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; import org.apache.http.HttpHost; -import org.apache.http.entity.ContentType; -import org.apache.http.entity.StringEntity; -import org.apache.http.nio.entity.NStringEntity; import org.apache.http.util.EntityUtils; -import org.elasticsearch.client.Response; +import org.elasticsearch.client.Request; import org.elasticsearch.client.RestClient; +import org.elasticsearch.client.Response; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; /** * Created by Lior Knaany on 4/7/18. @@ -36,29 +30,30 @@ public static void init() throws Exception { // delete test index if exists try { - esClient.performRequest("DELETE", "/test", Collections.emptyMap()); + Request deleteRequest = new Request("DELETE", "/test"); + esClient.performRequest(deleteRequest); } catch (Exception e) {} // create test index String mappingJson = "{\n" + " \"mappings\": {\n" + - " \"type\": {\n" + - " \"properties\": {\n" + - " \"embedding_vector\": {\n" + - " \"doc_values\": true,\n" + - " \"type\": \"binary\"\n" + - " },\n" + - " \"job_id\": {\n" + - " \"type\": \"long\"\n" + - " },\n" + - " \"vector\": {\n" + - " \"type\": \"float\"\n" + - " }\n" + + " \"properties\": {\n" + + " \"embedding_vector\": {\n" + + " \"doc_values\": true,\n" + + " \"type\": \"binary\"\n" + + " },\n" + + " \"job_id\": {\n" + + " \"type\": \"long\"\n" + + " },\n" + + " \"vector\": {\n" + + " \"type\": \"float\"\n" + " }\n" + " }\n" + " }\n" + "}"; - esClient.performRequest("PUT", "/test", Collections.emptyMap(), new NStringEntity(mappingJson, ContentType.APPLICATION_JSON)); + Request putRequest = new Request("PUT", "/test"); + putRequest.setJsonEntity(mappingJson); + esClient.performRequest(putRequest); } public static final ObjectMapper mapper = new ObjectMapper(); @@ -70,8 +65,7 @@ public static void init() throws Exception { @Test public void test() throws Exception { - final Map params = new HashMap<>(); - params.put("refresh", "true"); + final ObjectMapper mapper = new ObjectMapper(); final TestObject[] objs = {new TestObject(1, new float[] {0.0f, 0.5f, 1.0f}), new TestObject(2, new float[] {0.2f, 0.6f, 0.99f})}; @@ -79,7 +73,10 @@ public void test() throws Exception { final TestObject t = objs[i]; final String json = mapper.writeValueAsString(t); System.out.println(json); - final Response put = esClient.performRequest("PUT", "/test/type/" + t.jobId, params, new StringEntity(json, ContentType.APPLICATION_JSON)); + Request indexRequest = new Request("POST", "/test/_doc/" + t.jobId); + indexRequest.addParameter("refresh", "true"); + indexRequest.setJsonEntity(json); + final Response put = esClient.performRequest(indexRequest); System.out.println(put); System.out.println(EntityUtils.toString(put.getEntity())); final int statusCode = put.getStatusLine().getStatusCode(); @@ -96,7 +93,7 @@ public void test() throws Exception { " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + - " \"cosine\": true," + + " \"cosine\": false," + " \"field\": \"embedding_vector\"," + " \"vector\": [" + " 0.1, 0.2, 0.3" + @@ -108,16 +105,14 @@ public void test() throws Exception { " }," + " \"size\": 100" + "}"; - final Response res = esClient.performRequest("POST", "/test/_search", Collections.emptyMap(), new NStringEntity(body, ContentType.APPLICATION_JSON)); + Request searchRequest = new Request("POST", "/test/_search"); + searchRequest.setJsonEntity(body); + final Response res = esClient.performRequest(searchRequest); System.out.println(res); final String resBody = EntityUtils.toString(res.getEntity()); System.out.println(resBody); Assert.assertEquals("search should return status code 200", 200, res.getStatusLine().getStatusCode()); - Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":" + objs.length)); - // Testing Scores - final ArrayNode hitsJson = (ArrayNode)mapper.readTree(resBody).get("hits").get("hits"); - Assert.assertEquals(0.9941734, hitsJson.get(0).get("_score").asDouble(), 0); - Assert.assertEquals(0.95618284, hitsJson.get(1).get("_score").asDouble(), 0); + Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":{\"value\":" + objs.length)); } @AfterClass