Skip to content

Commit e741143

Browse files
committed
added automatic topic labelling
1 parent 9f04d36 commit e741143

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

README.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ Now, we need to create a config file for Feed Visualizer. The config file contai
5151
"pretrained_model": "all-mpnet-base-v2",
5252
"clust_dist_threshold": 4,
5353
"tsne_iter": 8000,
54-
"text_max_length": 2048
54+
"text_max_length": 2048,
55+
"topic_str_min_df": 0.25
5556
}
5657
```
5758

@@ -77,7 +78,8 @@ Here is some information on what each config setting does:
7778
"pretrained_model": "name of pretrained model. Here is a list of all valid model names: https://www.sbert.net/docs/pretrained_models.html#model-overview",
7879
"clust_dist_threshold": "Integer representing maximum radius of cluster. There is no correct value here. Experiment !",
7980
"tsne_iter": "Integer representing number of iterations for TSNE (higher is better)",
80-
"text_max_length": "Integer representing number of characters to read from content/description for semantic encoding."
81+
"text_max_length": "Integer representing number of characters to read from content/description for semantic encoding.",
82+
"topic_str_min_df": "A float. For example, a value of 0.25 means that only phrases present in 25% or more of the items in a cluster will be considered as the name of the cluster."
8183
}
8284
```
8385

config.json

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
{
2-
"input_directory": "feeds",
2+
"input_directory": "nasa",
33
"output_directory": "feeds_output",
44
"pretrained_model": "all-mpnet-base-v2",
5-
"clust_dist_threshold":1,
5+
"clust_dist_threshold":0.5,
66
"tsne_iter": 8000,
77
"text_max_length": 8048,
8-
"random_state": 45
8+
"random_state": 45,
9+
"topic_str_min_df": 0.25
910
}

visualization.html

+8-6
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@
5757
function makeplot() {
5858
d3.csv("data.csv" + '?' + Math.floor(Math.random() * 1000)).then((d) => {
5959
csv_data = d
60-
let clusterNumbers = d.map(a => parseInt(a.cluster))
60+
let clusterNumbers = []
61+
let topics = {}
62+
d.forEach(a => {topics[a.cluster] = a.topic;clusterNumbers.push(parseInt(a.cluster))})
6163
cluster_count = Math.max(...clusterNumbers) + 1
6264
d3.select('#clusters')
6365
.selectAll('span')
@@ -69,17 +71,17 @@
6971
.style("border", "1px solid grey")
7072
.style("min-width", "25px")
7173
.style("display", "inline-block")
72-
.style("color", "white")
73-
.style("text-shadow", "1px 1px grey")
74+
.style("color", function (d) { return (d < (cluster_count*.3) || d > (cluster_count*.7))? 'white':'black'})
75+
//.style("text-shadow", "1px 1px grey")
7476
.style("margin", "1px")
7577
.style("border-radius", "2px")
7678
.attr("data-clusterId", function (d) { return d })
7779
.on("mouseover", function (e, d) {
7880
//console.log(this.data.cluserId)
7981
//let currentClusterId = this.getAttribute("data-clusterId")
8082
let newTrace = JSON.parse(JSON.stringify(trace1));
81-
let new_colors = newTrace.marker.color.map(function (c,idx) {
82-
return csv_data[idx].cluster == d ? color(d / cluster_count) : "#e0eeeeee"
83+
let new_colors = newTrace.marker.color.map(function (c, idx) {
84+
return csv_data[idx].cluster == d ? color(d / cluster_count) : "#e0eeeeee"
8385
})
8486
newTrace.marker.color = new_colors
8587
drawPlot('myDiv', [newTrace], layout)
@@ -89,7 +91,7 @@
8991
drawPlot('myDiv', [trace1], layout)
9092
})
9193
.text(function (d) {
92-
return d;
94+
return d + " - " + topics[d];
9395
});
9496

9597
d.forEach(element => {

visualize.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from bs4 import BeautifulSoup, SoupStrainer
1414
from sentence_transformers import SentenceTransformer
1515
from sklearn.cluster import AgglomerativeClustering
16+
from sklearn.feature_extraction.text import CountVectorizer
1617
from sklearn.manifold import TSNE
1718
from tqdm import tqdm
1819

@@ -82,6 +83,22 @@ def get_coordinates(entries):
8283
clusters = clustering_model.fit_predict(tsne_output)
8384
return [x[0] for x in tsne.fit_transform(X)], [x[1] for x in tsne.fit_transform(X)], clusters
8485

86+
def find_topics(df):
    """Pick a short topic label for every cluster in ``df``.

    For each cluster id from 0 up to ``df["cluster"].max()``, the "label"
    values of that cluster are fed to a ``CountVectorizer`` that extracts
    unigrams and bigrams (English stop words removed, ``min_df`` taken from
    ``config["topic_str_min_df"]``). The longest phrase that survives is used
    as the topic of the cluster; on ties the first one in the vectorizer's
    feature order wins.

    Args:
        df: DataFrame with a "cluster" column and a "label" column.
            Assumes cluster ids are non-negative integers and labels are
            text -- TODO confirm against the caller.

    Returns:
        A list of topic strings indexed by cluster id. Clusters for which no
        topic could be derived (for example, no phrase satisfies ``min_df``
        or the cluster has no rows) get the placeholder "NA".
    """
    topics = []
    for cluster_id in range(0, df["cluster"].max() + 1):
        try:
            cluster_labels = df[df["cluster"] == cluster_id]["label"]
            vectorizer = CountVectorizer(
                ngram_range=(1, 2),
                min_df=config["topic_str_min_df"],
                stop_words="english",
            )
            # Only the learned vocabulary is needed, not the count matrix.
            vectorizer.fit(cluster_labels)
            possible_topics = vectorizer.get_feature_names_out()
            # Longest candidate phrase wins; max() keeps the first on ties.
            topics.append(max(possible_topics, key=len))
        except Exception:
            # Best-effort labelling: a cluster that cannot be named must not
            # abort the run. Catching Exception (rather than a bare except)
            # still lets KeyboardInterrupt / SystemExit propagate.
            topics.append("NA")
    return topics
85102

86103
def main():
87104
all_entries = get_all_entries(config["input_directory"])
@@ -106,7 +123,9 @@ def main():
106123
df = pd.DataFrame({'x': x, 'y': y, 'label': labels,
107124
'count': counts, 'url': entries.keys(), 'cluster': cluster_info})
108125

109-
126+
topics = find_topics(df)
127+
df["topic"] = df["cluster"].apply(lambda x : topics[x])
128+
print('Assigning cluster names !')
110129
if not os.path.exists(config["output_directory"]):
111130
os.makedirs(config["output_directory"])
112131
df.to_csv(config["output_directory"]+"/data.csv")

0 commit comments

Comments
 (0)