From b59299d62b1196e9a93c5a354acfd2000ff0fb56 Mon Sep 17 00:00:00 2001
From: Edward Li <LegendEddie18@gmail.com>
Date: Sat, 27 Mar 2021 21:30:29 -0400
Subject: [PATCH] Make methods require dictionary

---
 nlp.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/nlp.py b/nlp.py
index 3467239..13dc023 100644
--- a/nlp.py
+++ b/nlp.py
@@ -13,7 +13,7 @@ newsgroups_test = fetch_20newsgroups(subset='test')
 
 np.random.seed(400)
 stemmer = SnowballStemmer("english")
-NUM_TOPICS = 7
+NUM_TOPICS = 10
 
 
 def lemmatize_stemming(text):
@@ -30,13 +30,13 @@ def preprocess(text):
     return result
 
 
-def categorize_str(s: str, lda_model) -> int:
+def categorize_str(s: str, lda_model, dictionary) -> int:
     """
     Takes in a string to determine which topic it belongs to
     Returns the topic number as an int
     """
     processed_doc = preprocess(s)
-    dictionary = gensim.corpora.Dictionary([processed_doc])
+    # dictionary = gensim.corpora.Dictionary([processed_doc])
     bow_vector = dictionary.doc2bow(preprocess(s))
     ldaResults = sorted(lda_model[bow_vector], key=lambda tup: -1*tup[1])
     return ldaResults[0][0]
@@ -57,20 +57,26 @@ def create_model(documents: list):
                                            id2word=dictionary,
                                            passes=10,
                                            workers=2)
-    return lda_model
+    return (lda_model, dictionary)
 
 
-def update_model(s: str, lda_model):
+def update_model(s: str, lda_model, dictionary):
     """
     Takes in a string to update model
     Trains model using string
     """
     processed_doc = preprocess(s)
-    dictionary = gensim.corpora.Dictionary([processed_doc])
+    # dictionary = gensim.corpora.Dictionary([processed_doc])
+    dictionary.add_documents([processed_doc])
     bow_corpus = [dictionary.doc2bow(processed_doc)]
     lda_model.update(bow_corpus)
 
 
-# lda_model = create_model(newsgroups_train.data)
-# update_model("Hello everyone", lda_model)
-# print(categorize_str("Hello world", lda_model))
+# lda_model, dictionary = create_model(newsgroups_train.data)
+# print(dictionary.num_docs)
+# print(categorize_str("finance", lda_model, dictionary))
+# print(categorize_str("football", lda_model, dictionary))
+# print(categorize_str("virus", lda_model, dictionary))
+# print(categorize_str("economy", lda_model, dictionary))
+# update_model("Hello everyone", lda_model, dictionary)
+# print(categorize_str("Hello world", lda_model, dictionary))