Email Spam Classifier Java Application with SPARK

In this post we are going to develop an application that detects spam emails. The algorithm we will use is Logistic Regression, as implemented in Spark MLlib. No deep knowledge of the field is required, since the topics are described from as high-level a perspective as possible. Full working code is provided, together with a runnable application for further experiments on emails of your choice (please see the last section).

Logistic Regression

Logistic Regression is an algorithm used for classification problems. In a classification problem we are given a lot of labeled data (for example, spam and not spam), and when a new example comes in we want to know which category it belongs to. Since it is a Machine Learning algorithm, Logistic Regression is trained on labeled data, and based on that training it predicts the category of new examples.

Applications

In general, when a lot of data is available and we need to detect which category an example belongs to, Logistic Regression can be used, even though the results are not always satisfactory.

Health Care

Analyzing the health conditions of millions of patients to predict whether a patient will have a myocardial infarction. The same logic can be applied to predict whether a patient will develop a particular cancer, or, based on the words they use, whether they will be affected by depression, and so on. In these applications we have a considerable amount of data, so logistic regression usually gives good hints.

Image Categorization

Based on the color intensities of an image we can decide whether, let's say, it contains a human or not, or a car or not. Since this is also a categorization problem, we may use logistic regression to detect whether a picture contains characters, or even to recognize handwriting.

Message and Email Spam Classification

Probably one of the most common applications of logistic regression is message or email spam classification. Here the algorithm determines whether an incoming email or message is spam or not. Building a non-personalized filter requires a lot of data. Personalized filters usually perform better, because what counts as spam depends to some degree on the person's interests and background.

How it works

So we have a lot of labeled examples and want to train our algorithm to be smart enough to say whether new examples belong to one category or the other. For simplicity we will first consider only binary classification (1 or 0); the algorithm also extends easily to multi-class classification.

Insight

Usually we have multidimensional data, that is, data with many features, and each of these features contributes in some way to the final decision about which category a new example belongs to. For example, in a cancer classification problem we can have features like age, smoking or not, weight, height, family genome and so on. Features do not contribute equally; they have different impacts on the final outcome. For example, weight probably has a lower impact than family genome on cancer prediction. In Logistic Regression that is exactly what we are trying to find out: the weights (impact) of the features of our data. Once we have a lot of data examples we can determine a weight for each feature, and when new examples come in we just use these weights to see how they are categorized. In the cancer prediction example this is a weighted sum of the features:

$$z = \theta_1 \cdot \text{age} + \theta_2 \cdot \text{smoking} + \theta_3 \cdot \text{weight} + \theta_4 \cdot \text{height} + \theta_5 \cdot \text{genome} + \dots$$

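More formally, for the i-th example:

$$z^{(i)} = \theta_1 x_1^{(i)} + \theta_2 x_2^{(i)} + \dots + \theta_k x_k^{(i)} = \sum_{j=1}^{k} \theta_j x_j^{(i)}$$

where:
n -> number of examples (i = 1, …, n)
k -> number of features (j = 1, …, k)
θj -> weight for feature j
xj(i) -> value of feature j in the i-th example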
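To make the weighted sum concrete, here is a minimal Java sketch that computes z for one example. The class name `WeightedSum` and the numbers are made up for illustration; they are not weights learned from the spam data, and, like the formula above, the sketch has no separate intercept term.

```java
public class WeightedSum {

    /** Returns z = θ1*x1 + ... + θk*xk for a single example. */
    static double z(double[] theta, double[] x) {
        if (theta.length != x.length) {
            throw new IllegalArgumentException("one weight per feature is required");
        }
        double sum = 0.0;
        for (int j = 0; j < x.length; j++) {
            sum += theta[j] * x[j]; // feature j contributes in proportion to its weight
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] theta = {0.5, -1.0, 2.0}; // hypothetical weights
        double[] x = {2.0, 1.0, 0.5};      // hypothetical feature values of one example
        System.out.println(z(theta, x));   // 0.5*2 - 1*1 + 2*0.5 = 1.0
    }
}
```

A negative weight, like the second one here, lowers z, so that feature pushes the example towards category 0.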

Model Representation

In order to sort our data into categories we need a function (the hypothesis) that, based on an example's feature values, puts it into one of the two categories. Such a function exists: it is called the Sigmoid (or logistic) function, and graphically it is an S-shaped curve rising from 0 to 1.

The sigmoid function approaches 1 for large positive values on the X axis and approaches 0 for large negative values; at 0 it is exactly 0.5. So basically we have a model to represent two categories, and mathematically the function looks like this:

$$h_\theta(x) = g(z) = \frac{1}{1 + e^{-z}}$$

where z is the weighted sum of the features explained in the Insight section above.

In order to get discrete values 1 or 0, we classify an example as 1 when the function value (Y axis) is greater than 0.5, and as 0 when it is smaller than 0.5:

Y > 0.5 => 1 (spam, cancer)
Y < 0.5 => 0 (not spam, not cancer)

Since g(z) > 0.5 exactly when z > 0, this is equivalent to:

z > 0 => 1 (spam, cancer)
z < 0 => 0 (not spam, not cancer)
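A minimal Java sketch of the sigmoid hypothesis and the 0.5 threshold follows. The class and method names are illustrative, not taken from the application; an example with g(z) exactly 0.5 (z = 0) is assigned to class 0 here, which is an arbitrary choice for that boundary case.

```java
public class SigmoidClassifier {

    /** Sigmoid (logistic) function g(z) = 1 / (1 + e^(-z)), always between 0 and 1. */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Class 1 (spam) when g(z) > 0.5, otherwise class 0 (not spam). */
    static int classify(double z) {
        return sigmoid(z) > 0.5 ? 1 : 0;
    }

    public static void main(String[] args) {
        for (double z : new double[] {-6.0, -1.0, 0.0, 1.0, 6.0}) {
            System.out.printf("z=%5.1f  g(z)=%.4f  class=%d%n", z, sigmoid(z), classify(z));
        }
    }
}
```

Running `main` prints g(z) for a few values of z; the class flips from 0 to 1 exactly where z changes sign, which is why the two threshold rules above are equivalent.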

Cost function

We don’t want just ‘some weights’ but the best weights we can get from the actual data. To find them we need another function that calculates how good a solution is for a particular set of weights; with it we can compare solutions with different weights and pick the best one. This function is called the cost function, and in principle it just compares the hypothesis (sigmoid) value with the real data value: since the training data is labeled (spam or not spam), we compare the hypothesis prediction with the actual value, which we know for sure. We want the difference between the hypothesis and the real value to be as small as possible; ideally the cost function would be zero. More formally, the cost function is defined as:

$$J(\theta) = \frac{1}{n} \sum_{i=1}^{n} \mathrm{Cost}\left(h_\theta(x^{(i)}), y^{(i)}\right)$$

where $y^{(i)}$ is the real value or category (spam or not spam, 1 or 0) and $h_\theta(x^{(i)})$ is the hypothesis.

So basically the equation above calculates how good, on average, our predictions are compared to the real labeled data (y). Because there are two cases, 1 and 0, we use two different cost terms, which we will call h1 and h0 respectively. (We apply the log to the hypothesis because this makes the cost function convex, so minimizing it safely leads to the global minimum instead of a local one.)

Let's look at h1, the cost for the 1 category:

$$\mathrm{Cost}\left(h_\theta(x), y\right) = -\log\left(h_\theta(x)\right) \quad \text{if } y = 1$$

So instead of using the hypothesis directly, we apply the log to it. We do this because we want the cost to go to zero when the hypothesis is close to 1. Remember that ideally we want the cost function to be zero, meaning there is no difference between the hypothesis prediction and the labeled data. So for an example labeled 1, if the hypothesis predicts close to 0 the cost grows very large, heavily penalizing the wrong prediction; if the hypothesis predicts close to 1 the cost goes to zero, signaling a correct prediction.

Let's look at h0, the cost for the 0 category:

$$\mathrm{Cost}\left(h_\theta(x), y\right) = -\log\left(1 - h_\theta(x)\right) \quad \text{if } y = 0$$

In this case we again apply the log, but in a way that makes the cost go to zero when the hypothesis also predicts zero. So for an example labeled 0, if the hypothesis predicts close to 1 the cost grows very large, and if it predicts close to 0 the cost goes to zero, signaling a correct prediction.

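So now we have two cost functions, and all we need is to combine them into one. The resulting equation looks a bit messy, but in principle it is just a merge of the two cost functions explained above:

$$J(\theta) = -\frac{1}{n} \sum_{i=1}^{n} \left[ y^{(i)} \log\left(h_\theta(x^{(i)})\right) + \left(1 - y^{(i)}\right) \log\left(1 - h_\theta(x^{(i)})\right) \right]$$

Notice that the first term is just the cost function for h1 and the second term the cost function for h0. So if y = 1 the second term is eliminated, and if y = 0 the first term is eliminated.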
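The per-example cost and its average J(θ) can be sketched in Java as below. This is an illustrative sketch (the class name `LogLoss` is hypothetical), assuming the hypothesis values h are probabilities already produced by the sigmoid.

```java
public class LogLoss {

    /** Cost of a single prediction h (probability of spam) against the true label y (0 or 1). */
    static double cost(double h, int y) {
        // Same value as -[y*log(h) + (1 - y)*log(1 - h)] for y in {0, 1},
        // but never evaluates 0 * log(0), which would give NaN when h is exactly 0 or 1.
        return y == 1 ? -Math.log(h) : -Math.log(1.0 - h);
    }

    /** J(θ): the average cost over all n examples. */
    static double averageCost(double[] h, int[] y) {
        double total = 0.0;
        for (int i = 0; i < h.length; i++) {
            total += cost(h[i], y[i]);
        }
        return total / h.length;
    }

    public static void main(String[] args) {
        System.out.println(cost(0.99, 1)); // label 1, predicted 0.99: small cost (about 0.01)
        System.out.println(cost(0.01, 1)); // label 1, predicted 0.01: large cost (about 4.6)
        System.out.println(averageCost(new double[] {0.9, 0.2}, new int[] {1, 0}));
    }
}
```

Selecting the term by label instead of multiplying by y and (1 - y) is purely a numerical safeguard; mathematically it is the combined formula above.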

Minimize Cost Function

As we saw above, we want our cost function to be ideally zero, so that our predictions are as close as possible to the real (labeled) values. Fortunately there is a well-known algorithm for minimizing the cost function: Gradient Descent. Once we have the cost function, which compares our hypothesis with the real values, we can change our weights (θ) to lower the cost as much as possible. First we pick random values for θ, just to have a starting point, and calculate the cost function. Depending on the result we decrease or increase the θ values so that the cost goes down. We repeat this procedure until the cost is almost zero (e.g. 0.0001) or stops improving much from one iteration to the next.

Gradient descent does exactly this, except that it uses the derivative of the cost function to decide whether to lower or increase each θ value. Besides the derivative, which only gives the direction, it also uses a coefficient α (the learning rate) to define how much to change the θ values. Changing θ too much (large α) can make gradient descent fail to minimize the cost function, because a large step may overshoot the minimum and even move farther away from it. A small change of θ (small α) is safe, but the algorithm then needs a lot of time to reach the minimum of the cost function, since it progresses too slowly towards it (for a more visual explanation please look here). More formally, we repeat until convergence:

$$\theta_j := \theta_j - \alpha \frac{1}{n} \sum_{i=1}^{n} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$$

The term multiplied by α is the derivative of the cost function with respect to θj; it differs from one feature to another only by the factor xj(i). Since our data is multidimensional (k features), we do this for every weight θj, j = 1, …, k, updating all of them simultaneously.
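To show the update rule in action, here is a minimal batch gradient descent in plain Java on a made-up, one-feature toy data set. It is a sketch under stated assumptions, not Spark's optimizer: the learning rate α = 0.5, the all-zero starting weights and the fixed 1000 iterations are arbitrary choices, and instead of checking whether the cost is almost zero it simply runs a fixed number of iterations. The class name `GradientDescent` is hypothetical.

```java
public class GradientDescent {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Hypothesis h(x) = g(θ1*x1 + ... + θk*xk). */
    static double hypothesis(double[] theta, double[] x) {
        double z = 0.0;
        for (int j = 0; j < theta.length; j++) {
            z += theta[j] * x[j];
        }
        return sigmoid(z);
    }

    /** Cost J(θ), averaged over all examples. */
    static double cost(double[] theta, double[][] xs, int[] ys) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double h = hypothesis(theta, xs[i]);
            total += ys[i] == 1 ? -Math.log(h) : -Math.log(1.0 - h);
        }
        return total / xs.length;
    }

    /** Batch gradient descent: every iteration uses all n examples to update all θj at once. */
    static double[] train(double[][] xs, int[] ys, double alpha, int iterations) {
        int n = xs.length;
        int k = xs[0].length;
        double[] theta = new double[k]; // start from all-zero weights
        for (int it = 0; it < iterations; it++) {
            double[] gradient = new double[k];
            for (int i = 0; i < n; i++) {
                double error = hypothesis(theta, xs[i]) - ys[i]; // h(x(i)) - y(i)
                for (int j = 0; j < k; j++) {
                    gradient[j] += error * xs[i][j] / n; // derivative of J with respect to θj
                }
            }
            for (int j = 0; j < k; j++) {
                theta[j] -= alpha * gradient[j]; // step against the derivative
            }
        }
        return theta;
    }

    public static void main(String[] args) {
        // Toy data with one feature: negative values are "not spam" (0), positive ones "spam" (1).
        double[][] xs = {{-2.0}, {-1.0}, {1.0}, {2.0}};
        int[] ys = {0, 0, 1, 1};
        double[] theta = train(xs, ys, 0.5, 1000);
        System.out.println("cost before training: " + cost(new double[1], xs, ys)); // about 0.693 (log 2)
        System.out.println("cost after training:  " + cost(theta, xs, ys));
        System.out.println("learned weight: " + theta[0]);
    }
}
```

Printing the cost every few iterations is an easy way to see the effect of a too large or too small α described above.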

Algorithm Execution

Prepare Data

Before running the algorithm we need to do some data preprocessing in order to remove information that is not useful. The main idea for the preprocessing was taken from assignment 6 of the Coursera Machine Learning course. We do the following:

  • Lower-casing: The entire email is converted into lower case, so
    that capitalization is ignored (e.g., IndIcaTE is treated the same as
    Indicate).
  • Stripping HTML: All HTML tags are removed from the emails.
    Many emails often come with HTML formatting; we remove all the
    HTML tags, so that only the content remains.
  • Normalizing URLs: All URLs are replaced with the text “XURLX”.
  • Normalizing Email Addresses: All email addresses are replaced
    with the text “XEMAILX”.
  • Normalizing Numbers: All numbers are replaced with the text
    “XNUMBERX”.
  • Normalizing Dollars: All dollar signs ($) are replaced with the text
    “XMONEYX”.
  • Word Stemming: Words are reduced to their stemmed form. For example,
    “discount”, “discounts”, “discounted” and “discounting” are all
    replaced with “discount”. Sometimes, the Stemmer actually strips off
    additional characters from the end, so “include”, “includes”, “included”,
    and “including” are all replaced with “includ”.
  • Removal of non-words: Non-words and punctuation have been removed.
    All white spaces (tabs, newlines, spaces) have all been trimmed
    to a single space character.

The implementation looks like this:

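private List<String> filesToWords(String fileName) throws Exception {
    URI uri = this.getClass().getResource("/" + fileName).toURI();
    Path start = getPath(uri); // helper (not shown in this post) that resolves the resource URI
    List<String> emails;
    // Files.walk keeps directory handles open, so close the stream when done
    try (Stream<Path> paths = Files.walk(start)) {
        emails = paths.parallel()
                .filter(Files::isRegularFile)
                .flatMap(file -> {
                    try {
                        return Stream.of(new String(Files.readAllBytes(file)).toLowerCase());
                    } catch (IOException e) {
                        e.printStackTrace();
                        return Stream.<String>empty(); // skip files that cannot be read
                    }
                }).collect(Collectors.toList());
    }
    // clean every email, then split it into stemmed words
    return emails.stream().parallel()
            .flatMap(e -> tokenizeIntoWords(prepareEmail(e)).stream())
            .collect(Collectors.toList());
}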
private String prepareEmail(String email) {
    // the header ends at the first blank line; keep only the body
    int beginIndex = email.indexOf("\n\n");
    String withoutHeader = beginIndex > 0 ? email.substring(beginIndex) : email;
    String tagsRemoved = withoutHeader.replaceAll("<[^<>]+>", "");
    // URLs and email addresses must be replaced before numbers: otherwise digits
    // inside them are turned into "XNUMBERX " first and split them apart
    String urlReplaced = tagsRemoved.replaceAll("(http|https)://[^\\s]*", "XURLX ");
    String emailReplaced = urlReplaced.replaceAll("[^\\s]+@[^\\s]+", "XEMAILX ");
    String numberReplaced = emailReplaced.replaceAll("[0-9]+", "XNUMBERX ");
    String dollarReplaced = numberReplaced.replaceAll("[$]+", "XMONEYX ");
    return dollarReplaced;
}

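private List<String> tokenizeIntoWords(String email) {
    // every single character of this string acts as a separator
    String delim = "[' @$/#.-:&*+=[]?!(){},''\\\">_<;%'\t\n\r\f";
    StringTokenizer stringTokenizer = new StringTokenizer(email, delim);
    // Snowball-style stemmer (setCurrent/stem/getCurrent), e.g. org.tartarus.snowball.ext.PorterStemmer;
    // one instance per call is reused for all words of this email
    PorterStemmer stemmer = new PorterStemmer();
    List<String> wordsList = new ArrayList<>();
    while (stringTokenizer.hasMoreTokens()) {
        String word = stringTokenizer.nextToken();
        String nonAlphaNumericRemoved = word.replaceAll("[^a-zA-Z0-9]", "");
        if (nonAlphaNumericRemoved.isEmpty()) {
            continue; // the token consisted only of symbols
        }
        // reduce the word to its stem, e.g. "discounted" -> "discount"
        stemmer.setCurrent(nonAlphaNumericRemoved);
        stemmer.stem();
        wordsList.add(stemmer.getCurrent());
    }
    return wordsList;
}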

Transform Data

Once the emails are prepared, we need to transform the data into a structure the algorithm understands, such as matrices and feature vectors.

The first step is to build a ‘spam vocabulary’ by reading the words of all spam emails and counting them. For example, we count how many times ‘transaction’, ‘XMONEYX’, ‘finance’, ‘win’, ‘free’ … are used. Then we pick the 10,000 (featureSize) most frequent words. At this point we have a map of size 10,000 (featureSize) in which the key is the word and the value is its index, from 0 to 9,999. This map serves as a reference of possible spam words:

public Map<String, Integer> createVocabulary() throws Exception {
    // words of all spam emails from both spam folders
    List<String> all = new ArrayList<>(filesToWords("allInOneSpamBase/spam"));
    all.addAll(filesToWords("allInOneSpamBase/spam_2"));
    HashMap<String, Integer> countWords = countWords(all);

    // the featureSIze most frequent words, most frequent first
    List<String> mostFrequent = countWords.entrySet().stream()
            .sorted(Map.Entry.<String, Integer>comparingByValue().reversed())
            .limit(featureSIze)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());

    // word -> feature index (0 .. featureSIze - 1)
    Map<String, Integer> vocabulary = new HashMap<>();
    for (int i = 0; i < mostFrequent.size(); i++) {
        vocabulary.put(mostFrequent.get(i), i);
    }
    return vocabulary;
}
HashMap<String, Integer> countWords(List<String> all) {
    HashMap<String, Integer> countWords = new HashMap<>();
    for (String s : all) {
        countWords.merge(s, 1, Integer::sum); // first occurrence -> 1, afterwards +1
    }
    return countWords;
}

The next step is to go through each email (spam and not spam) in our data and count the word frequencies in it. Then we look up each of those words in the spam vocabulary (the 10,000-entry map created before). If the word is there (meaning it is a possible spam word), we store its frequency at the index the word has in the spam vocabulary map. In the end we build an N×10,000 matrix, where N is the number of emails and each row is a vector containing the frequencies of the spam vocabulary words in that email (0 if a spam word does not occur in the email).

For example (adapted from the Coursera assignment mentioned above), let's say we have a ‘spam vocabulary’ like below:

1 aa
2 how
3 abil
4 anyon
5 know
6 zero
7 zip

and an email in preprocessed form like the one below:

anyon know how much it cost to host a web portal well it depend on how
mani visitor your expect thi can be anywher from less than number buck
a month to a coupl of dollarnumb you should checkout XURLX or perhap
amazon ecnumb if your run someth big to unsubscrib yourself from thi
mail list send an email to XEMAILX

After the transformation we will have :

0 2 0 1 1 0 0 ==> So we have 0 aa, 2 how, 0 abil, 1 anyon, 1 know, 0 zero, 0 zip. This is a 1×7 matrix, since we have one email and a spam vocabulary of 7 words. The code looks like below:

private Vector transformToFeatureVector(Email email, Map<String, Integer> vocabulary) {
    List<String> words = email.getWords();
    HashMap<String, Integer> countWords = prepareData.countWords(words);
    double[] features = new double[featureSIze]; // featureSIze == 10,000
    for (Map.Entry<String, Integer> word : countWords.entrySet()) {
        // look the word up in the spam vocabulary
        Integer index = vocabulary.get(word.getKey());
        if (index != null) {
            // store the frequency at the same index the word has in the vocabulary
            features[index] = word.getValue();
        }
    }
    return Vectors.dense(features);
}
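The toy example above can be reproduced without Spark. This sketch (the class name `ToyFeatureVector` is illustrative) counts the vocabulary words directly in the preprocessed sample email; note that it uses 0-based indices, as the application's vocabulary map does, instead of the 1-based listing.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ToyFeatureVector {

    /** Counts how often each vocabulary word occurs; words outside the vocabulary are ignored. */
    static double[] toFeatures(List<String> words, Map<String, Integer> vocabulary) {
        double[] features = new double[vocabulary.size()];
        for (String word : words) {
            Integer index = vocabulary.get(word);
            if (index != null) {
                features[index]++; // one more occurrence of this spam-vocabulary word
            }
        }
        return features;
    }

    public static void main(String[] args) {
        String[] vocab = {"aa", "how", "abil", "anyon", "know", "zero", "zip"};
        Map<String, Integer> vocabulary = new HashMap<>();
        for (int i = 0; i < vocab.length; i++) {
            vocabulary.put(vocab[i], i); // 0-based index
        }
        String email = "anyon know how much it cost to host a web portal well it depend on how "
                + "mani visitor your expect thi can be anywher from less than number buck "
                + "a month to a coupl of dollarnumb you should checkout XURLX or perhap "
                + "amazon ecnumb if your run someth big to unsubscrib yourself from thi "
                + "mail list send an email to XEMAILX";
        double[] features = toFeatures(Arrays.asList(email.split(" ")), vocabulary);
        System.out.println(Arrays.toString(features)); // [0.0, 2.0, 0.0, 1.0, 1.0, 0.0, 0.0]
    }
}
```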

Execute and Results

The application can be downloaded and run without any knowledge of Java; the only requirement is that Java is installed on your computer. Feel free to test the algorithm with your own emails.

You can run the application from source by simply executing the RUN class or, if you prefer not to open it in an IDE, just run mvn clean install exec:java.

After that the application window should open.

First train the algorithm by clicking train with LR SGD or LR LBFGS (it may take 1-2 minutes); when training finishes, a pop-up shows the precision achieved. Don't worry about SGD versus LBFGS: they are just different ways of minimizing the cost function and give almost the same results (more here). Then copy and paste an email of your choice into the white area and hit test. A pop-up window will show the algorithm's prediction.

The precision achieved during my execution was approximately 97%, using a random 80% of the data for training and the remaining 20% for testing. No cross-validation was done; only a training set and a test set (for measuring accuracy) were used in this example. For more about dividing the data please refer here.

The code for training the algorithm is fairly simple:

public MulticlassMetrics execute() throws Exception {
    vocabulary = prepareData.createVocabulary();
    List<LabeledPoint> labeledPoints = convertToLabelPoints();
    sparkContext = createSparkContext();
    JavaRDD<LabeledPoint> labeledPointJavaRDD = sparkContext.parallelize(labeledPoints);
    JavaRDD<LabeledPoint>[] splits = labeledPointJavaRDD.randomSplit(new double[]{0.8, 0.2}, 11L);
    JavaRDD<LabeledPoint> training = splits[0].cache();
    JavaRDD<LabeledPoint> test = splits[1];


    linearModel = model.run(training.rdd());//training with 80% data

//testing with 20% data
    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
            (Function<LabeledPoint, Tuple2<Object, Object>>) p -> {
                Double prediction = linearModel.predict(p.features());
                return new Tuple2<>(prediction, p.label());
            }
    );

    return new MulticlassMetrics(predictionAndLabels.rdd());
}
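In the application, `randomSplit` and `MulticlassMetrics` do the splitting and the scoring. To illustrate what such an 80/20 hold-out evaluation computes, here is a plain-Java sketch; the class name, the reused seed 11 and the five predictions are made up. Each example is assigned independently, so the split is only approximately 80/20, and `java.util.Random` will not reproduce Spark's split even with the same seed.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class HoldOutEvaluation {

    /** Assigns each of n example indices to training (probability trainFraction) or to test. */
    static List<List<Integer>> randomSplit(int n, double trainFraction, long seed) {
        Random random = new Random(seed); // fixed seed => reproducible split
        List<Integer> training = new ArrayList<>();
        List<Integer> test = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (random.nextDouble() < trainFraction) {
                training.add(i);
            } else {
                test.add(i);
            }
        }
        return Arrays.asList(training, test);
    }

    /** Accuracy: fraction of test examples whose prediction equals the label. */
    static double accuracy(double[] predictions, double[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (predictions[i] == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    /** Precision for spam: of all emails predicted as spam (1.0), the fraction that really are spam. */
    static double spamPrecision(double[] predictions, double[] labels) {
        int predictedSpam = 0;
        int truePositives = 0;
        for (int i = 0; i < labels.length; i++) {
            if (predictions[i] == 1.0) {
                predictedSpam++;
                if (labels[i] == 1.0) {
                    truePositives++;
                }
            }
        }
        return predictedSpam == 0 ? 0.0 : (double) truePositives / predictedSpam;
    }

    public static void main(String[] args) {
        List<List<Integer>> split = randomSplit(1000, 0.8, 11L);
        System.out.println("training: " + split.get(0).size() + ", test: " + split.get(1).size());

        // made-up predictions for 5 test emails, only to exercise the metrics
        double[] predictions = {1, 0, 1, 1, 0};
        double[] labels = {1, 0, 0, 1, 0};
        System.out.println("accuracy:  " + accuracy(predictions, labels));      // 4 of 5 correct: 0.8
        System.out.println("precision: " + spamPrecision(predictions, labels)); // 2 of 3 predicted spam: 0.666...
    }
}
```

Accuracy and spam precision can differ noticeably: a filter that rarely flags anything can have high precision while still missing much of the spam.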