Fraud Detection with Java and Spark MLlib

In this post we develop the anomaly detection algorithm in Java using Spark MLlib. The full working code can be downloaded from GitHub. You can run the code with several different configurations and experiment on your own, without deep Java knowledge, by editing a configuration file.

In the previous post we implemented the same anomaly detection algorithm using Octave. We filtered 500,000 records (type TRANSFER only) out of 7 million to investigate the available data and gain insight into it. We also plotted several graphs showing what the data and the anomalies (frauds) look like. Since Octave loads all the data into memory, it has limitations with large data. For this reason we will use Spark to run anomaly detection on the larger data set of 7 million records.

Gaussian Distribution

This section gives a brief description of how the Gaussian function is used for anomaly detection; for a more detailed view please refer to the previous post. The Gaussian density function has a bell-shaped curve like the one below:

Regular data, which make up the majority of the data, tend to lie near the center of the bell curve, while anomalies lie on the edges, where points are rarer. At the same time, points on the edges have lower function values (probability density below 0.1) compared with those in the center (close to 0.4).

Following this example, we can say that every new example whose probability density is lower than 0.05 is an anomaly. Of course, we can tune the threshold depending on our needs. A big value means more examples are flagged as anomalies, and probably many of them are not actually anomalies; a small value, on the other hand, means we may miss anomalies, as the algorithm becomes more tolerant. There are several ways to calculate an optimal value, and one of them is described in detail in the previous post.
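A minimal sketch of the one-dimensional case, assuming a standard normal distribution (µ = 0, σ = 1) and the threshold 0.05 from the example above. The class and method names are hypothetical, not taken from the project's code:

```java
public class GaussianThreshold {

    // Univariate Gaussian density: p(x) = 1 / (sqrt(2*pi) * sigma) * exp(-(x - mu)^2 / (2 * sigma^2))
    public static double density(double x, double mu, double sigma) {
        double z = (x - mu) / sigma;
        return Math.exp(-0.5 * z * z) / (Math.sqrt(2 * Math.PI) * sigma);
    }

    // An example is flagged as an anomaly when its density falls below epsilon.
    public static boolean isAnomaly(double x, double mu, double sigma, double epsilon) {
        return density(x, mu, sigma) < epsilon;
    }

    public static void main(String[] args) {
        double mu = 0.0, sigma = 1.0, epsilon = 0.05;
        for (double x : new double[] {0.0, 1.0, 2.0, 3.0}) {
            System.out.printf("x=%.1f p(x)=%.4f anomaly=%b%n",
                    x, density(x, mu, sigma), isAnomaly(x, mu, sigma, epsilon));
        }
    }
}
```

At the center the density is about 0.399 (the "close to 0.4" above); x = 2.0 gives about 0.054 and stays just above the threshold, while x = 3.0 gives about 0.0044 and is flagged.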

The example above is one-dimensional, with data having only one feature. In reality, data have many more features, or dimensions. To plot such data in a graph, we reduce their dimensionality to two (2D) or even three (3D) dimensions using Principal Component Analysis (PCA). Below is an example in two dimensions:

Notice how normal data tend to stay together in the middle, inside the first and second circles, while anomalies lie on the edges, beyond the third circle. The circles on the graph show how the Gaussian bell curve is distributed over the data (normally this would be a 3D bell shape, but for simplicity it is shown in 2D).

In both cases, to place an example at a certain position on the bell-shaped graph we need to calculate two components: µ (mean) and σ² (variance). Once we have calculated the mean and variance, we can apply the formula explained here to calculate the probability density of a new example. If the probability is lower than a certain value, we flag the example as an anomaly; otherwise we mark it as normal. Please find a detailed explanation in the previous post.
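For reference, the formulas can be written out as below. This is a sketch assuming the common per-feature formulation (each of the n features modeled by its own Gaussian, and the densities multiplied), with m training examples x^(i); the notation is mine, not taken from the code. Some implementations divide the variance by m − 1 instead of m; for millions of records the difference is negligible.

```latex
\mu_j = \frac{1}{m}\sum_{i=1}^{m} x_j^{(i)}, \qquad
\sigma_j^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_j^{(i)} - \mu_j\right)^2

p(x) = \prod_{j=1}^{n} \frac{1}{\sqrt{2\pi}\,\sigma_j}
       \exp\!\left(-\frac{(x_j - \mu_j)^2}{2\sigma_j^2}\right)

x \text{ is flagged as an anomaly} \iff p(x) < \varepsilon
```

With n = 1 this reduces to the single bell curve above, whose peak value 1/(√(2π)σ) is about 0.399 for σ = 1.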

Spark and MLlib

This section gives a brief description of both Spark and MLlib. For a more detailed explanation and tutorials, please check out the official homepage.


Apache Spark is a cluster computing framework. Spark helps us execute jobs in parallel across different nodes in a cluster and then combine those results into one single result/response. It transforms our collection of data into a collection of elements distributed across the nodes of the cluster, called an RDD (resilient distributed dataset). For example, in a Java program we can transform a collection into an RDD capable of parallel operations like this:

// sc is a JavaSparkContext; collection is a java.util.List<LabeledPoint>
JavaRDD<LabeledPoint> paralleledTestData = sc.parallelize(collection);

Parallel collections are cut into partitions, and Spark executes one task per partition, so we typically want 2-4 partitions per CPU. We can control the number of partitions Spark creates by passing a second argument to the method, like sc.parallelize(collection, partitionNumber). Besides collections coming from the application, Spark can also create RDDs from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase and Amazon S3.

Once our data are transformed into an RDD, we can perform two kinds of parallel operations on the cluster nodes: transformations, which take an RDD as input and return a new RDD (like map), and actions, which take an RDD and return a single result (like reduce, count etc.). Transformations are lazy, similar to Java 8 streams, in that they do not run when defined but only when an action requests a result. As a consequence, the same transformation may be computed several times, once per action that needs it; to avoid that, we can persist or cache the RDD in memory.
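Java 8 streams make the laziness described above easy to observe. The sketch below (a hypothetical class, plain Java rather than Spark) counts how often a map function runs: defining the pipeline runs nothing, each terminal operation recomputes it, and materializing the result once plays a role loosely similar to cache()/persist(). The analogy is only loose: a Java stream can be consumed just once, so the pipeline is rebuilt through a Supplier.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class LazyEvaluationDemo {

    // Counts how many times the mapping function really executes.
    private static final AtomicInteger MAP_CALLS = new AtomicInteger();

    private static double square(double x) {
        MAP_CALLS.incrementAndGet();
        return x * x;
    }

    // Only defines the pipeline; map is lazy, so nothing runs yet.
    public static int callsAfterDefinition(List<Double> data) {
        MAP_CALLS.set(0);
        Stream<Double> defined = data.stream().map(LazyEvaluationDemo::square);
        return MAP_CALLS.get();
    }

    // Two terminal operations ("actions") on a pipeline rebuilt each time: map runs twice per element.
    public static int callsWithoutCache(List<Double> data) {
        MAP_CALLS.set(0);
        Supplier<Stream<Double>> pipeline = () -> data.stream().map(LazyEvaluationDemo::square);
        double sum = pipeline.get().mapToDouble(Double::doubleValue).sum();
        double max = pipeline.get().mapToDouble(Double::doubleValue).max().orElse(Double.NaN);
        return MAP_CALLS.get();
    }

    // Materializing once (rough analogue of cache()/persist()): map runs once per element.
    public static int callsWithCache(List<Double> data) {
        MAP_CALLS.set(0);
        List<Double> cached = data.stream().map(LazyEvaluationDemo::square).collect(Collectors.toList());
        double sum = cached.stream().mapToDouble(Double::doubleValue).sum();
        double max = cached.stream().mapToDouble(Double::doubleValue).max().orElse(Double.NaN);
        return MAP_CALLS.get();
    }

    public static void main(String[] args) {
        List<Double> data = Arrays.asList(1.0, 2.0, 3.0);
        System.out.println("after definition: " + callsAfterDefinition(data));   // 0
        System.out.println("two actions, no cache: " + callsWithoutCache(data)); // 6
        System.out.println("two actions, cached: " + callsWithCache(data));      // 3
    }
}
```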


Spark provides APIs in Java, Scala, Python and R, and also supports a rich set of higher-level tools, including Spark SQL for SQL and structured data processing, MLlib for machine learning, GraphX for graph processing, and Spark Streaming.

MLlib is Spark’s machine learning (ML) library. It provides several ready-to-use machine learning tools, such as:

  • ML Algorithms:
    • classification
    • regression
    • clustering
    • collaborative filtering
  • Featurization:
    • feature extraction
    • transformation
    • dimensionality reduction
    • selection

  • Utilities: linear algebra, statistics, data handling, etc.

Data Preparation

For some insight into the data and how the anomalies are distributed across the regular data, please refer here. As in the previous post, we need to prepare the data before running the algorithm. Below is a view of what the data look like:

We need to convert everything into numbers. Fortunately, most of the data are already numeric; only nameOrig and nameDest start with a character such as C, D or M, which we simply replace: C with 1, D with 2 and M with 3. We also convert the transaction types from strings to numbers as below:

  • PAYMENT  => 1
  • TRANSFER => 2
  • CASH_OUT => 3
  • DEBIT    => 4
  • CASH_IN  => 5

All the preparation is done in Java code using the Spark transformation operation map:

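    // imports used by this snippet
    import java.io.File;
    import java.util.Arrays;
    import java.util.stream.Stream;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;

    File file = new File(algorithmConfiguration.getFileName());

    return sc.textFile(file.getPath())
            .map(line -> {
                // encode transaction types and account-name prefixes as numbers
                line = line.replace("PAYMENT", "1")
                        .replace("TRANSFER", "2")
                        .replace("CASH_OUT", "3")
                        .replace("DEBIT", "4")
                        .replace("CASH_IN", "5")
                        .replace("C", "1")
                        .replace("M", "2");
                String[] split = line.split(",");
                //skip header
                if (split[0].equalsIgnoreCase("step")) {
                    return null;
                }
                double[] featureValues = Stream.of(split)
                        .mapToDouble(Double::parseDouble).toArray();
                if (algorithmConfiguration.isMakeFeaturesMoreGaussian()) {
                    // feature transformation (powering values) not shown in this excerpt
                }
                //always skip 9 and 10 because they are labels fraud or not fraud
                double label = featureValues[9];
                featureValues = Arrays.copyOfRange(featureValues, 0, 9);
                return new LabeledPoint(label, Vectors.dense(featureValues));
            })
            // drop the null produced for the header line
            .filter(point -> point != null);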
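The Spark job only wraps a per-line conversion, so the same logic can be checked without a cluster. Below is a hypothetical plain-Java variant (class and method names are mine) that converts each field separately instead of replacing letters across the whole line. It follows the codes used in the Spark snippet (types 1 to 5, C to 1, M to 2) and assumes the dataset's column order: step, type, amount, nameOrig, oldbalanceOrg, newbalanceOrig, nameDest, oldbalanceDest, newbalanceDest, isFraud, isFlaggedFraud.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class TransactionLineParser {

    // Transaction type codes (PAYMENT => 1 ... CASH_IN => 5).
    private static final Map<String, Integer> TYPE_CODES = new HashMap<>();
    static {
        TYPE_CODES.put("PAYMENT", 1);
        TYPE_CODES.put("TRANSFER", 2);
        TYPE_CODES.put("CASH_OUT", 3);
        TYPE_CODES.put("DEBIT", 4);
        TYPE_CODES.put("CASH_IN", 5);
    }

    // Assumed column positions of the type, the two account names and the isFraud label.
    private static final int TYPE = 1, NAME_ORIG = 3, NAME_DEST = 6, IS_FRAUD = 9;

    // Replaces the account prefix with a digit: C -> 1, M -> 2 (as in the Spark snippet).
    static String encodeName(String name) {
        switch (name.charAt(0)) {
            case 'C': return "1" + name.substring(1);
            case 'M': return "2" + name.substring(1);
            default: throw new IllegalArgumentException("Unknown account prefix: " + name);
        }
    }

    // Converts one CSV line into numbers; returns null for the header line.
    public static double[] parse(String line) {
        String[] fields = line.split(",");
        if (fields[0].equalsIgnoreCase("step")) {
            return null;
        }
        Integer typeCode = TYPE_CODES.get(fields[TYPE]);
        if (typeCode == null) {
            throw new IllegalArgumentException("Unknown transaction type: " + fields[TYPE]);
        }
        fields[TYPE] = String.valueOf(typeCode);
        fields[NAME_ORIG] = encodeName(fields[NAME_ORIG]);
        fields[NAME_DEST] = encodeName(fields[NAME_DEST]);
        return Arrays.stream(fields).mapToDouble(Double::parseDouble).toArray();
    }

    // Label = isFraud; features = the first 9 columns (isFraud and isFlaggedFraud dropped).
    public static double label(double[] row) {
        return row[IS_FRAUD];
    }

    public static double[] features(double[] row) {
        return Arrays.copyOfRange(row, 0, IS_FRAUD);
    }

    public static void main(String[] args) {
        // A made-up line in the dataset's format.
        double[] row = parse("1,CASH_OUT,500.0,C123,500.0,0.0,M456,0.0,0.0,1,0");
        System.out.println(label(row) + " " + Arrays.toString(features(row)));
    }
}
```

For the made-up line, main prints the label 1.0 followed by [1.0, 3.0, 500.0, 1123.0, 500.0, 0.0, 2456.0, 0.0, 0.0].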

After that, the file should look like the view below:

Because of the large file size and GitHub's file size limit, the data are not included with the code, so please download the file from here, rename it to allData.csv (change the constant FILE_NAME for a different name) and copy it into the folder data/.

Executing Algorithm

Let's see, step by step, how we execute the anomaly detection algorithm.

  1. From all the data (7 million records) we need to randomly choose a percentage for training, cross validation and test data (more about how they are used). The code that randomly picks regular and fraud data for a data set looks like below:
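    Collections.shuffle(regularData); // randomly re-order the data
    List<LabeledPoint> regular = ...; // the chosen percentage of the shuffled regular data
    List<LabeledPoint> fraud = ...;   // the chosen percentage of the fraud data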

    We run this code twice to get the training and cross validation data; what is left is the test data. We will see several percentage choices later on.

  2. Next we need to calculate µ (mean) and σ² (variance), as they are crucial for getting the probability of new examples. The code looks like below:
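    protected MultivariateStatisticalSummary getMultivariateSummary(GeneratedData<JavaRDD<LabeledPoint>> trainData) {
        // column-wise statistics (mean, variance, ...) over the feature vectors of the training data
        return Statistics.colStats(trainData.regularAndAnomalyData.map(e -> e.features()).rdd());
    }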
  3. As mentioned earlier, once we have the mean and variance we can calculate the probability value using the Gaussian formula. Based on that value we then decide whether an example is an anomaly or a regular one: we compare it with a threshold (epsilon), and if it is lower we mark the example as an anomaly, otherwise as regular. Choosing epsilon is crucial: a big value can cause the algorithm to flag a lot of false frauds, while a small value can make us miss frauds. We use the cross validation data together with precision and recall to choose the best epsilon.
    Double bestEpsilon = findBestEpsilon(sc, crossData, summary);
  4. Now we are ready to evaluate our algorithm on the test data (we also do an optional evaluation on the cross validation data).
    TestResult testResultFromTestData = testAlgorithmWithData(sc, getTestData(crossData), summary, bestEpsilon);
    fillTestDataResults(resultsSummary, testResultFromTestData);
    TestResult testResultFromCrossData = testAlgorithmWithData(sc, crossData.regularAndAnomalyData, summary, bestEpsilon);
    fillCrossDataResults(resultsSummary, testResultFromCrossData);
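Steps 1 to 3 can be sketched in plain Java, without Spark. This is not the project's implementation: the class and method names are hypothetical, the features are assumed to be modeled independently (one Gaussian per feature, densities multiplied), the variance is divided by m, and F1 is assumed as the way precision and recall are combined when choosing epsilon.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class GaussianAnomalyDetector {

    // Step 1: shuffle the pool and move `fraction` of it into a new list.
    // Whatever stays in the pool is left for the next data set (e.g. test data).
    public static <T> List<T> take(List<T> pool, double fraction, Random random) {
        Collections.shuffle(pool, random);
        int n = (int) Math.round(pool.size() * fraction);
        List<T> taken = new ArrayList<>(pool.subList(0, n));
        pool.subList(0, n).clear();
        return taken;
    }

    // Step 2: per-feature mean (result[0]) and variance (result[1]), dividing by m.
    public static double[][] meanAndVariance(List<double[]> rows) {
        int n = rows.get(0).length;
        int m = rows.size();
        double[] mean = new double[n];
        double[] variance = new double[n];
        for (double[] row : rows) {
            for (int j = 0; j < n; j++) mean[j] += row[j] / m;
        }
        for (double[] row : rows) {
            for (int j = 0; j < n; j++) variance[j] += (row[j] - mean[j]) * (row[j] - mean[j]) / m;
        }
        return new double[][] {mean, variance};
    }

    // Step 3a: product of per-feature Gaussian densities (assumes every variance is > 0;
    // with many features the product can underflow, where log-densities would be safer).
    public static double probability(double[] x, double[] mean, double[] variance) {
        double p = 1.0;
        for (int j = 0; j < x.length; j++) {
            double d = x[j] - mean[j];
            p *= Math.exp(-d * d / (2 * variance[j])) / Math.sqrt(2 * Math.PI * variance[j]);
        }
        return p;
    }

    // Step 3b: try `steps` thresholds between the smallest and largest cross validation
    // probability and keep the one with the best F1 score.
    public static double bestEpsilon(double[] probabilities, boolean[] isFraud, int steps) {
        double min = Arrays.stream(probabilities).min().getAsDouble();
        double max = Arrays.stream(probabilities).max().getAsDouble();
        double best = min;
        double bestF1 = -1;
        for (int s = 1; s <= steps; s++) {
            double epsilon = min + (max - min) * s / steps;
            int tp = 0, fp = 0, fn = 0;
            for (int i = 0; i < probabilities.length; i++) {
                boolean flagged = probabilities[i] < epsilon;
                if (flagged && isFraud[i]) tp++;
                else if (flagged) fp++;
                else if (isFraud[i]) fn++;
            }
            if (tp == 0) continue; // no fraud caught at this threshold, F1 would be 0
            double precision = (double) tp / (tp + fp);
            double recall = (double) tp / (tp + fn);
            double f1 = 2 * precision * recall / (precision + recall);
            if (f1 > bestF1) {
                bestF1 = f1;
                best = epsilon;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        Random random = new Random(42);
        List<double[]> regular = new ArrayList<>();
        for (int i = 0; i < 1000; i++) {
            regular.add(new double[] {random.nextGaussian(), random.nextGaussian()});
        }
        List<double[]> train = take(regular, 0.6, random); // 60% of regular data for training
        double[][] stats = meanAndVariance(train);
        System.out.println("p(centre)  = " + probability(new double[] {0, 0}, stats[0], stats[1]));
        System.out.println("p(outlier) = " + probability(new double[] {4, -4}, stats[0], stats[1]));
    }
}
```

In a full run, bestEpsilon would receive probability(...) for every cross validation example, and the chosen epsilon would then be applied unchanged to the test data (step 4).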

Before executing the algorithm we need to download the data (it is not packaged because of GitHub's file size limit), extract it and copy it as allData.csv into the folder data/ (data/allData.csv). Both the file location and the file name are configurable. The algorithm can be tested with the data and various options through the configuration file in config/, like below:

#60% of regular data used for training

#0% of fraud data used for training

#50% of frauds used as test data
#20% of regular data used as test data

#50% of frauds used as cross data
#20% of regular data used as cross data

#We can skip any of the 11 features, indexed from 0 to 10, e.g. 1,2,6,7

#Possible values :

#Possible values: SPARK and JAVA_STREAM

#How many times you want the algorithm to run

#Make features more Gaussian by raising current values to a power



After the configuration is changed, the application can be run by running the class Run from a Java IDE, or with Maven by running:

mvn clean install exec:java

Depending on your machine and configuration, it may take some time until the application finishes (for me it takes 2 minutes with type ALL). Your computer may also freeze a bit, as Spark uses 100% of the CPU at certain points, and you should expect the application to use a lot of memory (2-3 GB for me). You can see the results printed on the console, or look in the folder out/, where a *.txt file with the output is generated. As explained in more detail in the previous post, the algorithm is based on randomness, so you can configure it to run several times and expect one output file per execution.

Experiments and Results

From my experiments, frauds occur for only two types: TRANSFER and CASH_OUT. TRANSFER was investigated in detail in the previous post, where we achieved a pretty high rate of 99.7%.

When we run only for the CASH_OUT type and do not skip any columns/features, we get poor results:

, RUN =0
, successPercentage=0.13532555879494654
, failPercentage=0.8646744412050534
, trainFraudSize=0
, trainTotalDataSize=0
, transactionTypes=[CASH_OUT]
, timeInMilliseconds=58914
, testNotFoundFraudSize=1076
, testFoundFraudSize=180
, testFlaggedAsFraud=4873
, testFraudSize=1256
, testRegularSize=446023
, testTotalDataSize=447279….

So we are able to find only about 14% of frauds for this type. Last time we improved a lot by making the features look more like a Gaussian bell shape, but unfortunately that does not help this time, as they already do.

What we can do is look at our features and see whether we can add or skip some of them, since features sometimes introduce confusion and noise rather than benefit. Looking at the source of the data, we have the following description of fraud, which can help:

isFraud – This is the transactions made by the fraudulent agents inside the simulation. In this specific dataset the fraudulent behavior of the agents aims to profit by taking control of customers accounts and try to empty the funds by transferring to another account and then cashing out of the system.

So it looks like fraud here means money cashed out from a source account, and a transaction is probably considered fraud when a big part, or all, of the balance is cashed out. So I slowly started removing unneeded features, and found good results by removing features [1,2,3,7,8], i.e. type (we have only one type, so it makes no sense), amount, nameOrig, oldBalanceDest and newBalanceDest. When cashing out, the destination is probably not that important; what matters is the account the money is taken from. The destination matters less because that account may already hold money and look perfectly normal, whereas an emptied source account may signal fraudulent behavior. We keep the destination account name, as it may help in the case of fraudulent account names. The results look like this:
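The configuration lets you skip feature indices. As an illustration (the helper below is hypothetical, not the project's code), this is what skipping [1,2,3,7,8] leaves of one encoded CASH_OUT row, assuming the column order step, type, amount, nameOrig, oldbalanceOrg, newbalanceOrig, nameDest, oldbalanceDest, newbalanceDest:

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.IntStream;

public class FeatureSkipper {

    // Keeps only the feature columns whose index is not in `skipped`.
    public static double[] skip(double[] features, Set<Integer> skipped) {
        return IntStream.range(0, features.length)
                .filter(i -> !skipped.contains(i))
                .mapToDouble(i -> features[i])
                .toArray();
    }

    public static void main(String[] args) {
        // step, type, amount, nameOrig, oldbalanceOrg, newbalanceOrig, nameDest, oldbalanceDest, newbalanceDest
        double[] row = {1, 3, 500, 1123, 500, 0, 2456, 0, 0};
        Set<Integer> skipped = new HashSet<>(Arrays.asList(1, 2, 3, 7, 8));
        System.out.println(Arrays.toString(skip(row, skipped))); // [1.0, 500.0, 0.0, 2456.0]
    }
}
```

Only step, oldbalanceOrg, newbalanceOrig and nameDest remain: the source account balances stay, while the destination balances are dropped.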

Finish within 70027
, RUN =0
, successPercentage=0.8277453838678328
, failPercentage=0.17225461613216717
, trainFraudSize=0
, trainTotalDataSize=0
, transactionTypes=[CASH_OUT]
, timeInMilliseconds=67386
, testNotFoundFraudSize=218
, testFoundFraudSize=1016
, testFlaggedAsFraud=139467
, testFraudSize=1234
, testRegularSize=446808
, testTotalDataSize=448042

So this is a huge improvement: we went from 14% to 82.77%. Running ALL types together does not bring better results, even with different skipped features (feel free to try, as not everything has been explored). I was able to get some results by skipping only amount (2), but they are still not satisfactory, as a lot of non-frauds were flagged (1040950).

Finish within 128117
, RUN =0
, successPercentage=0.8700840131498844
, failPercentage=0.12991598685011568
, trainFraudSize=0
, trainTotalDataSize=0
, transactionTypes=[ALL]
, timeInMilliseconds=125252
, testNotFoundFraudSize=325
, testFoundFraudSize=2153
, testFlaggedAsFraud=1040950
, testFraudSize=2478
, testRegularSize=1272665
, testTotalDataSize=1275143

So in this case it is probably better to run the algorithm per type: when a possible transaction is made, we run it against the model for its type. In this way we can detect frauds more appropriately, as TRANSFER has a 99.7% rate and CASH_OUT 82.77%. Still, for CASH_OUT the rate is not that satisfactory, and other approaches, like more data, may be worth trying, but this has to be investigated first (intuition is often wrong and costs a lot of time). Since more data is very difficult to get in finance applications because of privacy, I would rather go in the direction of applying different algorithms here. When the CASH_OUT data were plotted, we got a view like below:
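Running the algorithm per type amounts to keeping one trained model, with its own epsilon, per transaction type, and routing each incoming transaction by its type. A minimal sketch with hypothetical names, where toy lambdas stand in for trained Gaussian densities:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.ToDoubleFunction;

public class PerTypeFraudDetector {

    // A trained model for one transaction type: its density function and its own epsilon.
    private static final class TypeModel {
        final ToDoubleFunction<double[]> density;
        final double epsilon;

        TypeModel(ToDoubleFunction<double[]> density, double epsilon) {
            this.density = density;
            this.epsilon = epsilon;
        }
    }

    private final Map<String, TypeModel> models = new HashMap<>();

    public void register(String type, ToDoubleFunction<double[]> density, double epsilon) {
        models.put(type, new TypeModel(density, epsilon));
    }

    // Routes each incoming transaction to the model trained for its type.
    public boolean isFraud(String type, double[] features) {
        TypeModel model = models.get(type);
        if (model == null) {
            return false; // assumption: no model means no frauds were observed for this type
        }
        return model.density.applyAsDouble(features) < model.epsilon;
    }

    // Toy setup: the lambdas stand in for trained Gaussian densities.
    public static PerTypeFraudDetector demo() {
        PerTypeFraudDetector detector = new PerTypeFraudDetector();
        detector.register("TRANSFER", f -> f[0], 0.01);
        detector.register("CASH_OUT", f -> f[0], 0.05);
        return detector;
    }

    public static void main(String[] args) {
        PerTypeFraudDetector detector = demo();
        System.out.println(detector.isFraud("CASH_OUT", new double[] {0.02})); // true
        System.out.println(detector.isFraud("TRANSFER", new double[] {0.02})); // false
        System.out.println(detector.isFraud("PAYMENT", new double[] {0.0}));   // false
    }
}
```

The same value 0.02 is flagged for CASH_OUT but not for TRANSFER because each type keeps its own threshold; returning false for types without a model is a design assumption that could just as well raise an error.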

Red points are normal data, magenta points are frauds that were not found, green points are frauds that were found, and blue points are examples wrongly flagged as fraud. The graph shows that the majority of frauds lie in the center of the normal data, so the algorithm struggles to detect them, although I believe there could be other ways to combine features, or even add new ones, that could greatly help.


We can configure the algorithm (see the property runsWith) to use either Spark or Java 8 streams for manipulating the data. Spark is a great framework if you want to run your code on several remote nodes in a cluster and aggregate the results on the requesting machine. In this post, however, the algorithm is executed locally, and Spark treats local resources, such as the number of CPUs, as the cluster's resources. Java 8 streams, on the other hand, easily provide parallelism through collection.stream().parallel(), of course only on the local machine. So, as part of the experiment, Java 8 streams were compared with Spark on a single machine.
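A minimal illustration of the Java 8 streams option: the same pipeline computing a feature mean, run sequentially and then with .parallel(). The class name and data are made up, and the timing printed by main is a naive single measurement (no JIT warm-up), so treat it as indicative only.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;

public class ParallelStreamMean {

    // Mean of one feature column; the same pipeline runs sequentially or in parallel.
    public static double columnMean(List<double[]> rows, int column, boolean parallel) {
        Stream<double[]> stream = rows.stream();
        if (parallel) {
            stream = stream.parallel(); // work is split across the common ForkJoinPool
        }
        return stream.mapToDouble(row -> row[column]).average().orElse(Double.NaN);
    }

    public static void main(String[] args) {
        Random random = new Random(7);
        List<double[]> rows = new ArrayList<>();
        for (int i = 0; i < 500_000; i++) {
            rows.add(new double[] {random.nextGaussian(), random.nextDouble()});
        }
        for (boolean parallel : new boolean[] {false, true}) {
            long start = System.nanoTime();
            double mean = columnMean(rows, 1, parallel);
            long millis = (System.nanoTime() - start) / 1_000_000;
            System.out.printf("parallel=%b mean=%.4f time=%d ms%n", parallel, mean, millis);
        }
    }
}
```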

The results show that Java 8 streams are faster locally, though not by a big factor: Java 111927 ms vs. Spark 128117 ms (about 112 s vs. 128 s), so streams were roughly 16-25 seconds faster when run with ALL. Please note that results may differ on your computer; feel free to suggest new results.

Since Spark is optimized for distributed computing, it is understandable that it has some overhead (partitioning, tasks and so on) compared with Java streams, which only need to consider the local machine and can optimize a lot there. Still, I can see the gap closing as the amount of data increases, even locally.

For small amounts of data Java 8 streams fit better, but for huge amounts of data Spark scales and fits better. It may be worth trying Spark not locally but configured on a cluster, for example on Amazon Web Services. For more details, please look at the two Java implementations in the code, which handle exactly the same algorithm with small, non-essential differences: FraudDetectionAlgorithmJavaStream and FraudDetectionAlgorithmSpark.

Found it useful? Feel free to share.