Why Scala is worth a try

In the last few years, several interesting languages have appeared for the Java Virtual Machine. I think the first notable one was Nice, whose syntax is very similar to Java's but less verbose, and which provides concepts like multimethods and Design by Contract. One of the most famous alternative JVM languages is probably Groovy, a dynamic scripting language comparable to Python or Ruby. The idea behind Groovy is to extend the original Java syntax to make coding much faster while staying compatible with Java.
A completely different approach to a JVM language is Clojure, a Lisp dialect optimized for concurrency and for interoperability with existing Java libraries. Clojure is a dynamic, functional language that can be used interactively from a REPL or compiled ahead of time into bytecode, and it has the most terrifying syntax imaginable for anyone who has never come into contact with (L(i(s(p)))) before.

All these languages have their own particular niche, except maybe Clojure, which could be a good Java replacement for parenthesis lovers. But what about the rest of us? I personally like the idea of a programming language that lets me use the wide range of Java libraries without writing tons of code for things a modern scripting language can do in three lines. Python comes with very good libraries and lets me write dense, readable code at the speed of light, but its execution speed is utterly devastating, and it can be quite challenging to use Python for big projects. I want a JVM language for fast prototyping that scales up!

And maybe I have found a candidate for this job: Scala is a very weird but interesting language. It uses type inference and comes with a REPL and an interpreter, which give you the feeling of working with a true scripting language, yet it can be used as a fully object-oriented, statically typed Java replacement with good support for design patterns. Scala compiles to JVM bytecode, interoperates with Java and runs unbelievably fast compared to interpreted scripting languages. It even supports functional programming with all those nice concepts like pattern matching, higher-order functions, lazy evaluation, etc.
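To make a few of these features more concrete, here is a small sketch of my own (not taken from any tutorial), assuming a current Scala 2.13 or Scala 3 compiler. It shows type inference, a higher-order function, pattern matching on a list and a lazy value; the object and function names are made up for the example.

```scala
object FeatureTour {
  // Type inference: no annotation needed, the compiler infers List[Int]
  val numbers = List(1, 2, 3, 4, 5)

  // Higher-order function: takes another function as a parameter
  def applyTwice(f: Int => Int, x: Int): Int = f(f(x))

  // Pattern matching on the shape of a list
  def describe(xs: List[Int]): String = xs match {
    case Nil          => "empty"
    case head :: Nil  => s"one element: $head"
    case head :: tail => s"starts with $head, ${tail.length} more"
  }

  // Lazy evaluation: the body runs only on first access, then the value is cached
  lazy val expensiveSum: Int = {
    println("computing...")
    numbers.sum
  }

  def main(args: Array[String]): Unit = {
    println(applyTwice(_ + 3, 10))  // 16
    println(describe(numbers))      // starts with 1, 4 more
    println(expensiveSum)           // prints "computing..." first, then 15
    println(expensiveSum)           // 15 (cached, no second "computing...")
  }
}
```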

A short motivation for using Scala can be found in a presentation by Martin Odersky. For those who want to jump directly into the code, I can recommend the tutorial “First Steps to Scala”, which doesn’t introduce all features but covers the ones most interesting for Java developers who want to try out Scala.

I’ve played around with Scala this evening and rewrote two Java classes I once created for a project at my university. The first class, KohonenNet, trains a self-organizing map with a very simple algorithm (for simplicity, I cut out distance modifiers as well as the decrements of the learning parameters). The second one is a helper class Neuron, which represents a single feature vector at a fixed position in the two-dimensional grid of the Kohonen network. I also removed the comments to make this post a bit shorter; I hope the code is still readable. The code is quite old and was written when I wasn’t as experienced in Java programming as I am now, so please be merciful with these snippets.
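The simplified training step that both versions below implement can be summarised as follows. This is read off the code itself (no neighbourhood weighting, no decreasing learning rate or radius), not a general definition of a self-organizing map. Here x is the stimulus with d components, w_n and p_n are the weight vector and the grid position of neuron n, η is the learning rate and r is the activation radius:

```latex
\begin{aligned}
b &= \arg\min_{n} \lVert x - w_n \rVert_2 \\
N &= \{\, n : \lVert p_n - p_b \rVert_2 \le r \,\} \\
w_{n,i} &\leftarrow w_{n,i} + \eta\,(x_i - w_{n,i}) \qquad \text{for all } n \in N,\ i = 1,\dots,d \\
E &= \frac{\lVert x - w_b \rVert_2}{\sqrt{d}}
\end{aligned}
```

One subtle difference: the Java version evaluates E after the update, because it mutates the best-matching unit's weight vector in place, while the Scala version keeps a reference to the immutable, not-yet-updated neuron and therefore reports the error from before the update.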

import java.util.ArrayList;

public class KohonenNet {
  private ArrayList<Neuron> neurons = new ArrayList<Neuron>();

  private int x;
  private int y;

  public KohonenNet(int x, int y, int sizeOfFeatureVector){
    for(int i = 0; i < x; i++){
      for(int j = 0; j < y; j++){
        neurons.add(new Neuron(i,j, sizeOfFeatureVector));
      }
    }
    this.x = x;
    this.y = y;
  }

  public double train(ArrayList<Double> stimulus, double learningRate, double activationRadius){
    Neuron bestMatchingUnit = findNearestNeuron(stimulus);
    ArrayList<Neuron> nearestNeighbourhood = getNearestNeighbourhood(bestMatchingUnit, activationRadius);
    updateWeightVectors(stimulus, bestMatchingUnit, nearestNeighbourhood, learningRate);
    return calculateErrorValue(stimulus, bestMatchingUnit);
  }

  private Neuron findNearestNeuron(ArrayList<Double> stimulus){
    Neuron result = null;
    double smallestDistance = Double.MAX_VALUE;
    for(int i = 0; i < neurons.size(); i++){
      double newDistance = euclideanDistance(stimulus, neurons.get(i).getWeightVector());
      if(newDistance < smallestDistance){
        result = neurons.get(i);
        smallestDistance = newDistance;
      }
    }
    return result;
  }

  private double euclideanDistance(ArrayList<Double> v1, ArrayList<Double> v2){
    double result = 0;
    ArrayList<Double> difference = new ArrayList<Double>();
    for(int i = 0; i < v1.size(); i++){
      double valueFromVec1 = v1.get(i).doubleValue();
      double valueFromVec2 = v2.get(i).doubleValue();
      difference.add(Double.valueOf(valueFromVec1 - valueFromVec2));
    }
    for(int i = 0; i < v1.size(); i++){
      double x = difference.get(i).doubleValue();
      result += x * x;
    }
    result = Math.sqrt(result);
    return result;
  }

  private ArrayList<Neuron> getNearestNeighbourhood(Neuron bestMatchingUnit, double activationRadius) {
    ArrayList<Neuron> nearestNeighbourhood = new ArrayList<Neuron>();
    for(Neuron neuron : neurons){
      if(euclideanDistance(bestMatchingUnit.getXYPosition(), neuron.getXYPosition()) <= activationRadius){
        nearestNeighbourhood.add(neuron);
      }
    }
    return nearestNeighbourhood;
  }

  private void updateWeightVectors(ArrayList<Double> stimulus,
      Neuron bestMatchingUnit, ArrayList<Neuron> nearestNeighbourhood, double learningRate) {
    for(Neuron neuron : nearestNeighbourhood){
      for(int w = 0; w < neuron.getWeightVector().size(); w++){
        double neuronWeight = neuron.getWeightVector().get(w).doubleValue();
        double stimulusWeight = stimulus.get(w).doubleValue();
        neuronWeight = neuronWeight + (learningRate * (stimulusWeight - neuronWeight));
        neuron.getWeightVector().set(w, Double.valueOf(neuronWeight));
      }
    }
  }

  private double calculateErrorValue(ArrayList<Double> stimulus, Neuron bestMatchingUnit) {
    double distance = euclideanDistance(stimulus, bestMatchingUnit.getWeightVector());
    return distance / Math.sqrt(stimulus.size());
  }

  public void print(){
    for(int i = 0; i < neurons.size(); i++){
      Neuron actualNeuron = neurons.get(i);
      System.out.print("("+actualNeuron.getPosX()+","+actualNeuron.getPosY()+")[");
      for(int w = 0; w < actualNeuron.getWeightVector().size(); w++){
        System.out.print(actualNeuron.getWeightVector().get(w).doubleValue());
        if(w < actualNeuron.getWeightVector().size()-1){
          System.out.print(",");
        }
      }
      System.out.println("]");
    }
    System.out.println("---");
  }

  public static void main(String[] args) {
    KohonenNet knet = new KohonenNet(2,2,2);
    knet.print();
    ArrayList<Double> stim = new ArrayList<Double>();
    stim.add(Double.valueOf(0.0));
    stim.add(Double.valueOf(1.0));
    knet.train(stim, 0.1,1);
    knet.train(stim, 0.1,1);
    double error = knet.train(stim, 0.1,1);
    knet.print();
    System.out.println("Error: " + error);
  }
}

— And here comes the Neuron.java class —

import java.util.ArrayList;

public class Neuron {
  public Neuron(int x, int y, int sizeOfWeightVector){
    posX = x;
    posY = y;
    for(int i = 0; i < sizeOfWeightVector; i++){
      weightVector.add(Double.valueOf(Math.random()));
    }
  }

  private int posX;
  private int posY;

  private ArrayList<Double> weightVector = new ArrayList<Double>();

  public ArrayList<Double> getXYPosition(){
    ArrayList<Double> result = new ArrayList<Double>();
    result.add(Double.valueOf(posX));
    result.add(Double.valueOf(posY));
    return result;
  }

  public int getPosX() {
    return posX;
  }

  public int getPosY() {
    return posY;
  }

  public ArrayList<Double> getWeightVector() {
    return weightVector;
  }
}

I compiled the classes with javac *.java and ran the program with java KohonenNet … nothing fancy so far. Now comes the version of the same program in Scala. Apart from the small immutable Neuron class, it doesn’t use stateful objects at all: its functions live in a singleton object (Scala’s counterpart to static methods), and it tries to be fully functional:

object KohonenNet {
  def getRandomNeurons(x:Int, y:Int, sizeOfFeatureVector:Int) =
    for(i <- List.range(0,x); j <- List.range(0,y)) yield Neuron.getRandom(i,j,sizeOfFeatureVector)

  def train(stimulus:List[Double], neurons:List[Neuron],
      learningRate:Double, activationRadius:Double) :Tuple2[List[Neuron], Double] = {
    val bestMatchingUnit = findNearestNeuron(stimulus, neurons)
    val (neighbours,otherNeurons) = getNearestNeighbourhood(bestMatchingUnit, neurons, activationRadius)
    val newNeighbours = updateWeightVectors(stimulus,neighbours,learningRate)
    (newNeighbours ::: otherNeurons, calculateErrorValue(stimulus,bestMatchingUnit))
  }

  // minBy finds the closest neuron in one pass (List.sort no longer exists in current Scala)
  private def findNearestNeuron(stimulus:List[Double], neurons:List[Neuron]) =
    neurons.minBy(neuron => euclideanDistance(stimulus, neuron.getWeightVector))

  private def euclideanDistance(k1:List[Double], k2:List[Double]) = {
    val sum = for(i <- List.range(0,k1.length)) yield math.pow(k1(i) - k2(i), 2)
    math.sqrt(sum.reduceLeft((x,y) => x+y))
  }

  private def getNearestNeighbourhood(bestMatchingUnit:Neuron,
      neurons:List[Neuron], activationRadius:Double) : Tuple2[List[Neuron], List[Neuron]] =
    neurons.partition((neuron) =>
      euclideanDistance(bestMatchingUnit.getXYPosition(), neuron.getXYPosition()) <= activationRadius)

  private def updateWeightVectors(stimulus:List[Double],
      neighbours:List[Neuron], learningRate:Double) : List[Neuron] =
    neighbours.map((neuron) => updateNeuron(stimulus,neuron,learningRate))

  private def updateNeuron(stimulus:List[Double], neuron:Neuron, learningRate:Double) : Neuron = {
    val oldWeights = neuron.getWeightVector
    val newWeights = for(i <- List.range(0,stimulus.size))
      yield (oldWeights(i) + (learningRate * (stimulus(i) - oldWeights(i))))
    new Neuron(neuron.getX, neuron.getY, newWeights)
  }

  private def calculateErrorValue(stimulus:List[Double], bestMatchingUnit:Neuron) = {
    val distance = euclideanDistance(stimulus, bestMatchingUnit.getWeightVector)
    distance / math.sqrt(stimulus.size)
  }

  def printNeurons(neurons:List[Neuron]) = {
    println("---")
    for(neuron <- neurons) println(neuron)
  }

  def main(args : Array[String]) : Unit = {
    val neurons = KohonenNet.getRandomNeurons(2,2,2)
    KohonenNet.printNeurons(neurons)
    var t = KohonenNet.train(List(0.0, 1.0), neurons, 0.1, 1)
    t = KohonenNet.train(List(0.0, 1.0), t._1, 0.1, 1)
    t = KohonenNet.train(List(0.0, 1.0), t._1, 0.1, 1)
    KohonenNet.printNeurons(t._1)
    val error = t._2
    println("Error: " + error)
  }
}
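The repeated t = KohonenNet.train(...) assignments in main are the last spot where the program falls back on a var. A fold can thread the neuron list through any number of training steps instead. Below is a small generic sketch, assuming Scala 2.13 or 3; the FoldTraining object and its repeat helper are hypothetical names, and the toy step in main only stands in for a real training step.

```scala
object FoldTraining {
  // Applies `step` n times, threading the state through foldLeft instead of a var.
  // Returns the final state and the error reported by the last step
  // (NaN if n == 0, because no step was run).
  def repeat[S](initial: S, n: Int)(step: S => (S, Double)): (S, Double) =
    (1 to n).foldLeft((initial, Double.NaN)) { case ((state, _), _) => step(state) }

  def main(args: Array[String]): Unit = {
    // Toy stand-in for a training step: one weight moved 10% towards a target of 1.0,
    // reporting the distance to the target before the move
    val toyStep: Double => (Double, Double) =
      w => (w + 0.1 * (1.0 - w), math.abs(1.0 - w))

    val (finalWeight, lastError) = repeat(0.0, 3)(toyStep)
    println(s"weight = $finalWeight, error = $lastError")
  }
}
```

Since train already returns a (List[Neuron], Double) pair, the three calls in main could become something like FoldTraining.repeat(neurons, 3)(ns => KohonenNet.train(List(0.0, 1.0), ns, 0.1, 1)).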

— And here comes the Neuron.scala file —

class Neuron(x: Double, y:Double, weightVector:List[Double]) {
  def getXYPosition() = List(x, y)
  def getX = x
  def getY = y
  def getWeightVector = weightVector
  override def toString = "(" + getX + "," + getY + "/" + getWeightVector + ")"
}

object Neuron {
  def getRandom(x: Int, y:Int, sizeOfWeightVector:Int) =
      new Neuron(x,y,for ( i <- List.range(0,sizeOfWeightVector)) yield Math.random)
}
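Since Neuron.scala only holds data, it is a natural candidate for a case class. The following is my own sketch, assuming Scala 2.13 or 3, not the code I actually ran: the case class generates toString, equals/hashCode and a copy method, and the hypothetical updatedTowards helper uses copy to return a moved neuron without touching the original. The get* accessors are kept only so that the KohonenNet object can keep calling them.

```scala
// Sketch only: Neuron as a case class (assumes Scala 2.13 or 3)
final case class Neuron(x: Double, y: Double, weightVector: List[Double]) {
  // Accessors kept so that code written against the original class still compiles
  def getXYPosition(): List[Double] = List(x, y)
  def getX: Double = x
  def getY: Double = y
  def getWeightVector: List[Double] = weightVector

  // Hypothetical helper: returns a new neuron moved towards the stimulus,
  // leaving this instance unchanged
  def updatedTowards(stimulus: List[Double], learningRate: Double): Neuron =
    copy(weightVector = weightVector.zip(stimulus).map { case (w, s) => w + learningRate * (s - w) })
}

object Neuron {
  // Factory with random weights in [0, 1), as in the original companion object
  def getRandom(x: Int, y: Int, sizeOfWeightVector: Int): Neuron =
    Neuron(x.toDouble, y.toDouble, List.fill(sizeOfWeightVector)(math.random()))
}
```

With this, updateNeuron could shrink to neuron.updatedTowards(stimulus, learningRate). Printing a neuron then shows the generated case-class format instead of the custom toString.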

To compile the code, I used the command scalac -d bin *.scala, which creates a bunch of class files and puts them into the bin directory. Then I ran the program with scala -cp bin/ KohonenNet.
The Scala code is much shorter than the Java code (Java LOC ~95, Scala LOC ~45), but more difficult to read. Maybe it’s my lack of practice, maybe it’s the constant use of special characters as operators, such as :::, <- or =>. List comprehensions and typical functional methods like reduceLeft, partition or map make programming in Scala very fast and productive. I needed some time to get used to the syntax, but once the first lines are written, everything slowly begins to make sense. The documentation is adequate but could sometimes be more detailed. I like Scala’s idea of class parameters, which saves you from copying constructor arguments into member variables.
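For readers new to these methods, here is a small sketch of the operations the KohonenNet object relies on, applied to plain lists of doubles. It assumes Scala 2.13 or 3; the CollectionOps object and its function names are hypothetical.

```scala
object CollectionOps {
  // for/yield: squared differences of two vectors, as in euclideanDistance
  def squaredDifferences(a: List[Double], b: List[Double]): List[Double] =
    for ((p, q) <- a.zip(b)) yield (p - q) * (p - q)

  // reduceLeft combines the elements pairwise from the left: ((x1 + x2) + x3) + ...
  def distance(a: List[Double], b: List[Double]): Double =
    math.sqrt(squaredDifferences(a, b).reduceLeft(_ + _))

  // map: move every weight towards the stimulus, as in updateNeuron
  def moveTowards(weights: List[Double], stimulus: List[Double], rate: Double): List[Double] =
    weights.zip(stimulus).map { case (w, s) => w + rate * (s - w) }

  // partition: split into (elements satisfying the predicate, the rest) in one pass
  def splitByRadius(distances: List[Double], radius: Double): (List[Double], List[Double]) =
    distances.partition(_ <= radius)

  def main(args: Array[String]): Unit = {
    println(distance(List(0.0, 1.0), List(3.0, 5.0)))          // 5.0
    println(moveTowards(List(0.0, 0.5), List(0.0, 1.0), 0.5))   // List(0.0, 0.75)
    println(splitByRadius(List(0.2, 1.5, 0.9, 3.0), 1.0))        // (List(0.2, 0.9),List(1.5, 3.0))
  }
}
```

Note that reduceLeft throws an exception on an empty list; sum or foldLeft(0.0)(_ + _) are safe alternatives when a vector might be empty.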
