Factory design pattern in Scala with a real-world example

Let's say we have a data comparison use case: we need to compare raw data with aggregated data according to a deviation mode (HIGHER or LOWER) that is configured in the database, and we need a separate algorithm for each stream (upstream or downstream), which is also set in the database.

Prerequisite

Please refer to the earlier post, factory-pattern-real-world-example-java, which shows how we implemented this in Java. It is a prerequisite for understanding the article below.

There are two common ways to implement the factory pattern in Scala. The first is to define an apply method on the factory object, def apply(mode: String, stream: String): Mode, so that callers can simply write DeviationMode(mode, stream). The second is an ordinary named method such as def getDeviationMode(mode: String, stream: String): Mode.
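In Scala, writing `Factory(args)` on an object is shorthand for `Factory.apply(args)`; the compiler inserts the `apply` call for you. The minimal sketch below shows both calling styles side by side on a cut-down factory. The names `SimpleMode`, `SimpleModeFactory` and `ApplyDemo` are hypothetical stand-ins for illustration, not part of the original code.

```scala
// Hypothetical, cut-down stand-ins for Mode and DeviationMode,
// used only to show the two calling styles side by side.
trait SimpleMode {
  def meanDeviation(rawMetricValue: Double, averageMetricValue: Double): Double
}

object SimpleModeFactory {

  // Private: code outside this object cannot write `new Higher`
  private class Higher extends SimpleMode {
    def meanDeviation(rawMetricValue: Double, averageMetricValue: Double): Double =
      (rawMetricValue - averageMetricValue) / averageMetricValue * 100
  }

  // First approach: apply lets callers write SimpleModeFactory("HIGHER")
  def apply(mode: String): SimpleMode = mode match {
    case "HIGHER" => new Higher
    case other    => throw new IllegalArgumentException(s"Unknown mode: $other")
  }

  // Second approach: an ordinary, explicitly named factory method
  def getMode(mode: String): SimpleMode = apply(mode)
}

object ApplyDemo {
  def main(args: Array[String]): Unit = {
    // The compiler rewrites this to SimpleModeFactory.apply("HIGHER")
    val viaApply = SimpleModeFactory("HIGHER")
    val viaMethod = SimpleModeFactory.getMode("HIGHER")
    println(viaApply.meanDeviation(150, 100))  // prints 50.0
    println(viaMethod.meanDeviation(150, 100)) // prints 50.0
  }
}
```

Here the named method simply delegates to `apply`, which keeps the selection logic in one place; the post's version instead repeats that logic in both methods.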

Below is the full code. Because the concrete classes DownStreamHigherMode, DownStreamLowerMode, UpstreamHigherMode and UpstreamLowerMode are declared private inside the DeviationMode object, only the factory can create instances of them. We also define a trait, Mode, which every concrete class extends, so that the factory methods have a common return type.


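trait Mode {

  // Returns true when the deviation (in percent) exceeds the configured threshold
  def compareMeanDeviation(deviation: Double, threshold: Double): Boolean =
    deviation > threshold

  // Percentage deviation of the raw value from the aggregated (average) value;
  // each concrete mode supplies its own rule
  def meanDeviation(rawMetricValue: Double, averageMetricValue: Double): Double

}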


object DeviationMode {

  // The concrete classes are private, so only this factory can instantiate them
  private class DownStreamHigherMode extends Mode {

    override def meanDeviation(rawMetricValue: Double, averageMetricValue: Double): Double =
      if (rawMetricValue > averageMetricValue && averageMetricValue > 22)
        ((rawMetricValue - averageMetricValue) / averageMetricValue) * 100
      else
        0.0
  }

  private class DownStreamLowerMode extends Mode {

    override def meanDeviation(rawMetricValue: Double, averageMetricValue: Double): Double =
      if (averageMetricValue > rawMetricValue && averageMetricValue > 40)
        ((averageMetricValue - rawMetricValue) / averageMetricValue) * 100
      else
        0.0
  }

  private class UpstreamHigherMode extends Mode {

    override def meanDeviation(rawMetricValue: Double, averageMetricValue: Double): Double =
      if (rawMetricValue > averageMetricValue && averageMetricValue > 32)
        ((rawMetricValue - averageMetricValue) / averageMetricValue) * 100
      else
        0.0
  }

  private class UpstreamLowerMode extends Mode {

    override def meanDeviation(rawMetricValue: Double, averageMetricValue: Double): Double =
      if (averageMetricValue > rawMetricValue && averageMetricValue > 15)
        ((averageMetricValue - rawMetricValue) / averageMetricValue) * 100
      else
        0.0
  }

  // First approach: apply lets callers write DeviationMode(mode, stream)
  def apply(mode: String, stream: String): Mode = (stream, mode) match {
    case ("upstream", "LOWER")    => new UpstreamLowerMode
    case ("upstream", "HIGHER")   => new UpstreamHigherMode
    case ("downstream", "LOWER")  => new DownStreamLowerMode
    case ("downstream", "HIGHER") => new DownStreamHigherMode
    case _ => throw new IllegalArgumentException(s"Invalid mode '$mode' or stream '$stream'.")
  }

  // Second approach: an ordinary, explicitly named factory method
  def getDeviationMode(mode: String, stream: String): Mode = (stream, mode) match {
    case ("upstream", "LOWER")    => new UpstreamLowerMode
    case ("upstream", "HIGHER")   => new UpstreamHigherMode
    case ("downstream", "LOWER")  => new DownStreamLowerMode
    case ("downstream", "HIGHER") => new DownStreamHigherMode
    case _ => throw new IllegalArgumentException(s"Invalid mode '$mode' or stream '$stream'.")
  }

}
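For reference, here is what the four private classes compute, read directly off the code. Writing r for rawMetricValue, a for averageMetricValue and f for the class-specific floor (the average must exceed it before any deviation is reported):

```latex
d_{\mathrm{HIGHER}} =
\begin{cases}
\dfrac{r - a}{a} \times 100, & r > a \text{ and } a > f \\[4pt]
0, & \text{otherwise}
\end{cases}
\qquad
d_{\mathrm{LOWER}} =
\begin{cases}
\dfrac{a - r}{a} \times 100, & a > r \text{ and } a > f \\[4pt]
0, & \text{otherwise}
\end{cases}
```

The floor f is 22 for DownStreamHigherMode, 40 for DownStreamLowerMode, 32 for UpstreamHigherMode and 15 for UpstreamLowerMode, and compareMeanDeviation reports a breach when d > threshold. With the driver's values (upstream, LOWER, r = 10, a = 60, threshold 60), d = (60 - 10) / 60 × 100 ≈ 83.33, which exceeds 60, so a breach is generated.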

Driver code to test both approaches:


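object Driver {

  def main(args: Array[String]): Unit = {

    // In the real use case this value is read from the database
    val deviationMode: String = "LOWER"

    // First approach: DeviationMode(...) is compiled to DeviationMode.apply(...)
    val mode: Mode = DeviationMode(deviationMode, "upstream")

    // Second approach: an explicitly named factory method
    val mode2: Mode = DeviationMode.getDeviationMode(deviationMode, "upstream")

    // raw = 10, average = 60: (60 - 10) / 60 * 100 ≈ 83.33 percent, above the threshold of 60
    val mean: Double = mode.meanDeviation(10, 60)
    val breach: Boolean = mode.compareMeanDeviation(mean, 60)

    val mean2: Double = mode2.meanDeviation(10, 60)
    val breach2: Boolean = mode2.compareMeanDeviation(mean2, 60)

    println(if (breach) "breach generated" else "no breach generated")
    println(if (breach2) "breach generated" else "no breach generated")
  }
}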
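One possible refinement, not the approach used in the post: parse the database strings into sealed types once, so the factory itself can never receive an invalid combination, and report bad configuration as an Option instead of throwing. This is a sketch under stated assumptions: every name here (StreamType, DeviationKind, TypedMode, TypedDeviationMode, FloorMode, fromStrings) is hypothetical, and the four concrete classes are collapsed into one parameterised private class to keep it short, reusing the floors hard-coded in the post's classes.

```scala
// Hypothetical typed keys; the post itself passes plain strings from the database.
sealed trait StreamType
case object Upstream extends StreamType
case object Downstream extends StreamType

sealed trait DeviationKind
case object Higher extends DeviationKind
case object Lower extends DeviationKind

trait TypedMode {
  def meanDeviation(rawMetricValue: Double, averageMetricValue: Double): Double
  def compareMeanDeviation(deviation: Double, threshold: Double): Boolean =
    deviation > threshold
}

object TypedDeviationMode {

  // One private, parameterised class stands in for the post's four classes;
  // the floors (22, 40, 32, 15) are the ones hard-coded there.
  private final class FloorMode(higher: Boolean, floor: Double) extends TypedMode {
    def meanDeviation(rawMetricValue: Double, averageMetricValue: Double): Double =
      if (averageMetricValue <= floor) 0.0
      else if (higher && rawMetricValue > averageMetricValue)
        (rawMetricValue - averageMetricValue) / averageMetricValue * 100
      else if (!higher && averageMetricValue > rawMetricValue)
        (averageMetricValue - rawMetricValue) / averageMetricValue * 100
      else 0.0
  }

  // Matching on sealed types lets the compiler warn about a missing combination,
  // so no IllegalArgumentException branch is needed here.
  def apply(stream: StreamType, kind: DeviationKind): TypedMode = (stream, kind) match {
    case (Downstream, Higher) => new FloorMode(higher = true, floor = 22)
    case (Downstream, Lower)  => new FloorMode(higher = false, floor = 40)
    case (Upstream, Higher)   => new FloorMode(higher = true, floor = 32)
    case (Upstream, Lower)    => new FloorMode(higher = false, floor = 15)
  }

  private def parseStream(value: String): Option[StreamType] = value match {
    case "upstream"   => Some(Upstream)
    case "downstream" => Some(Downstream)
    case _            => None
  }

  private def parseKind(value: String): Option[DeviationKind] = value match {
    case "HIGHER" => Some(Higher)
    case "LOWER"  => Some(Lower)
    case _        => None
  }

  // Entry point for raw database values: None instead of an exception for unknown input
  def fromStrings(mode: String, stream: String): Option[TypedMode] =
    for {
      s <- parseStream(stream)
      k <- parseKind(mode)
    } yield apply(s, k)
}
```

With this shape, TypedDeviationMode.fromStrings("LOWER", "upstream") yields a Some, while a misconfigured value such as "SIDEWAYS" yields None, and the caller decides whether to log, skip or fail. The trade-off is an extra parsing step in exchange for moving invalid-input handling out of the factory.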