Amazon EMR: Distributing a Python Job Using the AWS SDK for Java and JSch

In this article I will demonstrate how to distribute a Python job on Amazon EMR using the AWS SDK for Java and JSch.

If we have standalone Java or Python code that we want to distribute across the nodes of an Amazon EMR cluster, we can use the AWS SDK to get the list of core/task nodes and run the standalone code on each of them. The example below launches a Python script on all the available nodes of the cluster, passing a different input path to each job, so that we achieve parallel processing. Although a distributed framework such as Apache Spark is the recommended way to do parallel processing, this approach is useful when we want to reuse existing Python code on EMR with parallel execution. We will use JSch, a Java implementation of SSH2, which can connect to an sshd server and execute a shell script from Java.

Amazon EMR provides a managed Hadoop framework that makes it easy, fast, and cost-effective to process vast amounts of data across dynamically scalable Amazon EC2 instances. You can also run other popular distributed frameworks such as Apache Spark, HBase, Presto, and Flink in EMR, and interact with data in other AWS data stores such as Amazon S3 and Amazon DynamoDB.
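
Throughout the article we will need the id of a running cluster (it looks like j-XXXXXXXXXXXXX and is shown in the EMR console). If you prefer to look it up programmatically, below is a minimal sketch using the same EMR client and the same AwsCredentials.properties file used later in this article; the class name ListClusters is arbitrary.

import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.PropertiesCredentials;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder;
import com.amazonaws.services.elasticmapreduce.model.ClusterSummary;

public class ListClusters {

    public static void main(String[] args) throws Exception {
        AmazonElasticMapReduce client = AmazonElasticMapReduceClientBuilder.standard()
                .withCredentials(new AWSStaticCredentialsProvider(new PropertiesCredentials(
                        ListClusters.class.getClassLoader().getResourceAsStream("AwsCredentials.properties"))))
                .build();
        // Print the id, name and state of every cluster visible to the account.
        for (ClusterSummary summary : client.listClusters().getClusters()) {
            System.out.println(summary.getId() + " " + summary.getName() + " "
                    + summary.getStatus().getState());
        }
    }
}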

Below are the prerequisites for using the code in this article:

  1.  A running Amazon EMR cluster with SSH enabled on the core/task nodes so that we can connect from a remote machine.
  2.  A copy of the private key used to connect to the cluster. Amazon EC2 does not keep a copy of your private key, so if you lose it there is no way to recover it; make sure you keep the .pem file in a safe location.

For more information on creating Amazon EC2 key pairs, refer to the link below:

https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-key-pairs.html#having-ec2-create-your-key-pair
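
The key pair can also be created programmatically with the AWS SDK for Java. Below is a minimal sketch, assuming the aws-java-sdk-ec2 dependency is on the classpath and credentials are available in the default credential chain; the key name test matches the test.pem file used later in the startup code.

import java.io.FileWriter;

import com.amazonaws.services.ec2.AmazonEC2;
import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
import com.amazonaws.services.ec2.model.CreateKeyPairRequest;
import com.amazonaws.services.ec2.model.CreateKeyPairResult;

public class CreateKeyPair {

    public static void main(String[] args) throws Exception {
        AmazonEC2 ec2 = AmazonEC2ClientBuilder.defaultClient();
        // EC2 returns the private key material only once, at creation time.
        CreateKeyPairResult result = ec2.createKeyPair(new CreateKeyPairRequest().withKeyName("test"));
        try (FileWriter writer = new FileWriter("test.pem")) {
            writer.write(result.getKeyPair().getKeyMaterial());
        }
        System.out.println("Saved private key for key pair: " + result.getKeyPair().getKeyName());
    }
}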

Let's start with the code that connects to the cluster and retrieves the IP addresses of the master and core/task nodes (the private IP for the master, the public IPs for the core/task nodes). The access key and secret key are configured in the AwsCredentials.properties file.


import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.PropertiesCredentials;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder;
import com.amazonaws.services.elasticmapreduce.model.Cluster;
import com.amazonaws.services.elasticmapreduce.model.DescribeClusterRequest;
import com.amazonaws.services.elasticmapreduce.model.Instance;
import com.amazonaws.services.elasticmapreduce.model.ListInstancesRequest;
import com.amazonaws.services.elasticmapreduce.model.ListInstancesResult;
import com.aws.emr.model.EmrModel;

public class ClusterInfo {

    private String clusterId;

    public ClusterInfo(String clusterId) {
        super();
        this.clusterId = clusterId;
    }

    public String getClusterId() {
        return clusterId;
    }

    public void setClusterId(String clusterId) {
        this.clusterId = clusterId;
    }

    // Builds the EMR client from the credentials configured in AwsCredentials.properties.
    public EmrModel init() throws Exception {
        AWSCredentials credentials = null;
        try {
            credentials = new PropertiesCredentials(
                    ClusterInfo.class.getClassLoader().getResourceAsStream("AwsCredentials.properties"));
        } catch (IOException e1) {
            System.out.println(e1.getMessage());
            System.exit(-1);
        }
        AmazonElasticMapReduce client = AmazonElasticMapReduceClientBuilder.standard()
                .withCredentials(new AWSStaticCredentialsProvider(credentials)).build();
        return new EmrModel(client, getClusterId());
    }

    // Returns the private IP of the master node, identified by matching its public DNS name.
    public String getMasterNodeIp() throws Exception {
        EmrModel model = init();
        Cluster cluster = model.getAmazonElasticMapReduce()
                .describeCluster(new DescribeClusterRequest().withClusterId(getClusterId())).getCluster();
        ListInstancesResult instances = model.getAmazonElasticMapReduce()
                .listInstances(new ListInstancesRequest().withClusterId(getClusterId()));
        String masterDnsName = cluster.getMasterPublicDnsName();
        for (Instance instance : instances.getInstances()) {
            if (instance.getPublicDnsName().equals(masterDnsName)) {
                return instance.getPrivateIpAddress();
            }
        }
        throw new Exception("Failed to find master node private ip.");
    }

    // Returns the public IPs of all core/task nodes (every instance that is not the master).
    public List<String> getSlaveNodeIp() throws Exception {
        EmrModel model = init();
        Cluster cluster = model.getAmazonElasticMapReduce()
                .describeCluster(new DescribeClusterRequest().withClusterId(getClusterId())).getCluster();
        ListInstancesResult instances = model.getAmazonElasticMapReduce()
                .listInstances(new ListInstancesRequest().withClusterId(getClusterId()));
        String masterDnsName = cluster.getMasterPublicDnsName();
        List<String> listOfSlaveNode = new ArrayList<String>();
        for (Instance instance : instances.getInstances()) {
            if (!instance.getPublicDnsName().equals(masterDnsName)) {
                listOfSlaveNode.add(instance.getPublicIpAddress());
            }
        }
        return listOfSlaveNode;
    }

}
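
As a quick sanity check, the class can be exercised on its own; the cluster id below is a hypothetical placeholder.

public class ClusterInfoDemo {

    public static void main(String[] args) throws Exception {
        // Hypothetical cluster id; replace it with the id of your own running cluster.
        ClusterInfo clusterInfo = new ClusterInfo("j-1K48XXXXXXHCB");
        System.out.println("Master node: " + clusterInfo.getMasterNodeIp());
        System.out.println("Core/task nodes: " + clusterInfo.getSlaveNodeIp());
    }
}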

Below is the EmrModel class used by ClusterInfo.


import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce;

public class EmrModel {

    private AmazonElasticMapReduce amazonElasticMapReduce;

    private String emrId;

    public EmrModel(AmazonElasticMapReduce amazonElasticMapReduce, String emrId) {
        super();
        this.amazonElasticMapReduce = amazonElasticMapReduce;
        this.emrId = emrId;
    }

    public AmazonElasticMapReduce getAmazonElasticMapReduce() {
        return amazonElasticMapReduce;
    }

    public void setAmazonElasticMapReduce(AmazonElasticMapReduce amazonElasticMapReduce) {
        this.amazonElasticMapReduce = amazonElasticMapReduce;
    }

    public String getEmrId() {
        return emrId;
    }

    public void setEmrId(String emrId) {
        this.emrId = emrId;
    }

}

Below is the Java code that uses JSch, a Java implementation of SSH2, to connect to the core/task nodes and execute the Python script on each of them. The keyname parameter is the path to the .pem file we downloaded when creating the Amazon EC2 key pair, and path is the input for the Python script.


import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;

import com.jcraft.jsch.Channel;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;

public class DistributePythonJob {

    private ClusterInfo clusterInfo;

    public DistributePythonJob(ClusterInfo clusterInfo) {
        super();
        this.clusterInfo = clusterInfo;
    }

    public ClusterInfo getClusterInfo() {
        return clusterInfo;
    }

    public void setClusterInfo(ClusterInfo clusterInfo) {
        this.clusterInfo = clusterInfo;
    }

    // Connects to the node over SSH as the hadoop user and feeds the generated
    // shell script to a shell channel; the script's trailing "exit" closes the shell.
    public void executePython(String publicDNS, String keyname, String path) throws IOException {
        JSch jSch = new JSch();
        try {
            jSch.addIdentity(keyname);
            Session session = jSch.getSession("hadoop", publicDNS, 22);
            java.util.Properties configuration = new java.util.Properties();
            configuration.put("StrictHostKeyChecking", "no");
            session.setConfig(configuration);
            System.out.println("Connecting to new instance " + publicDNS + " via SSH....");
            session.connect();
            Channel channel = session.openChannel("shell");
            channel.setOutputStream(System.out);
            File shellScript = getShellScriptFile(path);
            FileInputStream fin = new FileInputStream(shellScript);
            byte fileContent[] = new byte[(int) shellScript.length()];
            fin.read(fileContent);
            InputStream in = new ByteArrayInputStream(fileContent);
            channel.setInputStream(in);
            channel.connect();
            fin.close();
        } catch (JSchException e) {
            e.printStackTrace();
        }
    }

    // Writes a small shell script that launches the Python job; path carries the
    // arguments for the script (the S3 file key and the local download location).
    public static File getShellScriptFile(String path) {
        String filename = "basic_commands.sh";
        File scriptFile = new File(filename);
        try {
            PrintStream out = new PrintStream(new FileOutputStream(scriptFile));
            out.println("echo \"ssh into the instance.\"");
            out.println("python readDataFromS3.py " + path);
            out.println("exit");
            out.close();
        } catch (Exception e) {
            System.err.println("The following error occurred: " + e.getMessage());
        }
        return scriptFile;
    }

}
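
As a design note, the shell channel above simply streams a script to an interactive shell, which keeps the example close to what you would type by hand, but it gives no direct access to the remote command's exit code. If you need that, JSch's exec channel is an alternative; below is a minimal sketch of the same Python invocation using it, assuming a session already connected as in executePython above.

import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.Session;

public class ExecChannelExample {

    // Runs a single command on the remote node and returns its exit status.
    public static int runCommand(Session session, String path) throws Exception {
        ChannelExec channel = (ChannelExec) session.openChannel("exec");
        channel.setCommand("python readDataFromS3.py " + path);
        channel.setOutputStream(System.out);
        channel.setErrStream(System.err);
        channel.connect();
        // Wait for the remote command to finish before reading the exit status.
        while (!channel.isClosed()) {
            Thread.sleep(500);
        }
        int exitStatus = channel.getExitStatus();
        channel.disconnect();
        return exitStatus;
    }
}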

We can execute the above code using the startup code below. The cluster id is configured in the cluster config file.


import java.util.List;

public class StartUp {

    public static void main(String[] args) throws Exception {

        ConfigFile scriptFile = new ConfigFile(Constants.CLUSTER_CONFIG, FileType.property);
        String clusterId = scriptFile.getString("clusterId");
        String[] path = { "inputPath-1", "inputPath-2", "inputPath-3" };
        ClusterInfo clusterInfo = new ClusterInfo(clusterId);
        List<String> listOfNode = clusterInfo.getSlaveNodeIp();
        DistributePythonJob distribute = new DistributePythonJob(clusterInfo);

        // Launch the Python job on each node, passing a different input path to each one.
        int i = 0;
        for (String publicDNS : listOfNode) {
            distribute.executePython(publicDNS, "test.pem", path[i % path.length]);
            i++;
        }

    }

}
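
Because executePython leaves the channel open after connecting, the loop above already starts the jobs without waiting for them to finish. If you prefer to make the fan-out explicit, for example to bound concurrency or to block until all jobs complete, an ExecutorService can drive the SSH connections. Below is a minimal sketch under the same assumptions as StartUp (hypothetical cluster id, test.pem key file).

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class ParallelStartUp {

    public static void main(String[] args) throws Exception {
        ClusterInfo clusterInfo = new ClusterInfo("j-1K48XXXXXXHCB"); // hypothetical cluster id
        List<String> listOfNode = clusterInfo.getSlaveNodeIp();
        DistributePythonJob distribute = new DistributePythonJob(clusterInfo);
        String[] path = { "inputPath-1", "inputPath-2", "inputPath-3" };

        // One thread per node so all SSH connections are opened concurrently.
        ExecutorService executor = Executors.newFixedThreadPool(listOfNode.size());
        for (int i = 0; i < listOfNode.size(); i++) {
            final String publicDNS = listOfNode.get(i);
            final String inputPath = path[i % path.length];
            executor.submit(() -> {
                try {
                    distribute.executePython(publicDNS, "test.pem", inputPath);
                } catch (Exception e) {
                    e.printStackTrace();
                }
            });
        }
        executor.shutdown();
        executor.awaitTermination(10, TimeUnit.MINUTES);
    }
}

One caveat: getShellScriptFile writes a single shared basic_commands.sh, so with true concurrency each task should ideally write its own script file.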

Below is the simple Python code that downloads data from Amazon S3 to the local machine. We pass the file key and the download location as command-line arguments.


import boto3
import sys

def download_file(file, input_location):
    # Download the object with the given key from S3 to the local path.
    BUCKET_NAME = 'adarsh-test-bucket'
    s3 = boto3.resource('s3')
    s3.Bucket(BUCKET_NAME).download_file(file, input_location)

def main():
    # argv[0] is the script name, so the file key and the download
    # location arrive as argv[1] and argv[2].
    download_file(sys.argv[1], sys.argv[2])

if __name__ == "__main__":
    main()