Scala code to get a secret stored in Azure Key Vault from Databricks

In this article, we will walk through an example of accessing a secret stored in Azure Key Vault from a Databricks notebook using Scala.

To access a secret in Key Vault programmatically, we need a PEM file containing the certificate and its private key. It can be downloaded from the Azure portal: navigate to the Key Vault service, open Certificates, select the certificate, and choose the option to download it in PEM format. This file is what the code uses to connect to the key vault.

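Along with the PEM file, we also need the Azure Key Vault URL and the client ID of the Databricks application. The client ID can be copied from the Enterprise applications service: navigate to Enterprise applications, search for the Databricks application, and copy the application ID from the Overview tab. The code authenticates as this application using the certificate. The certificate must therefore be registered with that application, and the application must be allowed to read secrets from the vault.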

Next, upload the downloaded PEM file into DBFS using the Databricks CLI. If you don’t have the Databricks CLI installed, you can install it by following the instructions at the URL below:

https://docs.microsoft.com/en-us/azure/databricks/dev-tools/cli/

Once the Databricks CLI is installed, run the command below to configure its connection to your Databricks workspace (and therefore to DBFS):

databricks configure --token

If you don’t have a token, you can generate one in the Databricks workspace: navigate to User Settings and click Generate New Token. Once the Databricks CLI is connected, use the command below to upload the PEM file into DBFS:

dbfs cp -r directory_where_PEM_file_exists dbfs:/certificate_path

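Below is the Scala code, which can be run in a Databricks notebook or in any Scala project running on Databricks. It connects to Azure Key Vault and gets the value of a secret stored there. The PEM file is read with java.io, so its path must be a local filesystem path rather than a dbfs:/ URI. On Databricks this is typically the /dbfs/... mount, e.g. /dbfs/certificate_path/your_file.pem.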

import java.io.{File, FileReader}
import java.security.cert.X509Certificate
import java.security.{PrivateKey, Security}
import java.util.concurrent.Executors
import com.microsoft.aad.adal4j.{AsymmetricKeyCredential, AuthenticationContext}
import com.microsoft.azure.keyvault.KeyVaultClient
import com.microsoft.azure.keyvault.authentication.KeyVaultCredentials
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.openssl.PEMParser
import org.bouncycastle.openssl.jcajce.{JcaPEMKeyConverter, JceOpenSSLPKCS8DecryptorProviderBuilder}
import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo


/** Sample program that loads a certificate and private key from a PEM file, uses them to
  * authenticate to Azure AD as the given client (application) id, and reads one secret from
  * Azure Key Vault.
  *
  * Replace the `ENTER_...` placeholders at the bottom before running.
  */
object ReadAzureKeyVaultSecretFromDataBricks extends App {

  /** Reads an X.509 certificate and its private key from a PEM file.
    *
    * Supported PEM objects are certificates, unencrypted PKCS#8 private keys and encrypted
    * PKCS#8 private keys. If the file holds a certificate chain, the first certificate is kept
    * (by convention the leaf, i.e. the one matching the key). If several private keys are
    * present, the last one wins.
    *
    * @param filePath local filesystem path of the PEM file
    * @param password password for an encrypted PKCS#8 key; not used for unencrypted keys
    * @return the certificate and private key found in the file
    * @throws IllegalArgumentException if the file contains an unsupported PEM object, or if
    *                                  the certificate or the private key is missing
    */
  def readPemFile(filePath: String, password: String): X509CertificateKey = {
    // Security.addProvider is a no-op (returns -1) when the provider is already registered.
    Security.addProvider(new BouncyCastleProvider)
    // Loop-invariant converters, created once per file rather than once per PEM object.
    val keyConverter = new JcaPEMKeyConverter().setProvider("BC")
    val certificateConverter = new JcaX509CertificateConverter()
    val pemParserInstance = new PEMParser(new FileReader(new File(filePath)))
    var privateKey: PrivateKey = null
    var x509Certificate: X509Certificate = null
    try {
      var parsedObjInstance = pemParserInstance.readObject
      while (parsedObjInstance != null) {
        parsedObjInstance match {
          case holder: X509CertificateHolder =>
            // Keep the first certificate; later ones are assumed to be chain (CA) certificates.
            if (x509Certificate == null) x509Certificate = certificateConverter.getCertificate(holder)
          case encryptedInfo: PKCS8EncryptedPrivateKeyInfo =>
            val decryptorProvider = new JceOpenSSLPKCS8DecryptorProviderBuilder().build(password.toCharArray)
            privateKey = keyConverter.getPrivateKey(encryptedInfo.decryptPrivateKeyInfo(decryptorProvider))
          case info: PrivateKeyInfo =>
            privateKey = keyConverter.getPrivateKey(info)
          case other =>
            throw new IllegalArgumentException(
              s"Unsupported PEM object ${other.getClass.getName} in $filePath")
        }
        parsedObjInstance = pemParserInstance.readObject
      }
    } finally {
      // Always release the underlying FileReader, even when parsing or decryption fails.
      pemParserInstance.close()
    }
    // Fail fast here instead of passing nulls into the ADAL credential later.
    require(x509Certificate != null, s"No X.509 certificate found in PEM file $filePath")
    require(privateKey != null, s"No private key found in PEM file $filePath")
    X509CertificateKey(x509Certificate, privateKey)
  }

  /** Certificate and matching private key loaded from a PEM file. */
  case class X509CertificateKey(x509Certificate: X509Certificate, privateKey: PrivateKey)

  /** Key Vault credential that obtains access tokens with certificate-based client authentication.
    *
    * The PEM file is read once, when the credential is constructed.
    *
    * @param clientId    Azure AD application (client) id to authenticate as
    * @param path        local filesystem path of the PEM file holding the certificate and key
    * @param pemPassword password for an encrypted private key; empty for unencrypted keys
    */
  class AzureKeyVaultCredential(clientId: String, path: String, pemPassword: String = "") extends KeyVaultCredentials {

    private val x509CertificateKey = ReadAzureKeyVaultSecretFromDataBricks.readPemFile(path, pemPassword)

    /** Acquires an access token for `resource` from the `authorization` authority. */
    override def doAuthenticate(authorization: String, resource: String, scope: String): String = {
      // ADAL runs the token request on this executor. The pool's threads are non-daemon, so it
      // must be shut down after each request, or every call would leak a thread.
      val executor = Executors.newFixedThreadPool(1)
      try {
        // `false` disables authority validation (the validateAuthority constructor argument).
        val authenticationContext = new AuthenticationContext(authorization, false, executor)
        val keyCredential = AsymmetricKeyCredential.create(
          clientId, x509CertificateKey.privateKey, x509CertificateKey.x509Certificate)
        // No completion callback; block on the returned future instead.
        authenticationContext.acquireToken(resource, keyCredential, null).get.getAccessToken
      } finally {
        executor.shutdown()
      }
    }
  }

  /** Small helper around [[KeyVaultClient]] for reading secret values.
    *
    * @param keyVaultCredentials credentials used by the underlying client
    * @param defaultKeyVaultUri  vault URL used by [[getValue]]
    */
  class KeyVaultUtility(keyVaultCredentials: KeyVaultCredentials, defaultKeyVaultUri: String) {
    val client: KeyVaultClient = new KeyVaultClient(keyVaultCredentials)

    /** Returns the value of secret `name` stored in the vault at `keyVaultUri`. */
    def getValueInKv(keyVaultUri: String)(name: String): String =
      client.getSecret(keyVaultUri, name).value()

    /** Reads a secret by name from `defaultKeyVaultUri`. */
    val getValue: String => String = getValueInKv(defaultKeyVaultUri) _
  }

  // Replace these placeholders before running.
  val azureClientID = "ENTER_AZURE_CLIENT_ID"
  // NOTE(review): readPemFile opens this with java.io.FileReader, so it must be a local filesystem
  // path, not a dbfs:/ URI. On Databricks this is presumably the /dbfs/... mount; confirm.
  val dbfsPathToPEMFile = "ENTER_DBFS_PATH_FOR_PEM_FILE"
  val azure_key_vault_url = "ENTER_KEY_VAULT_URL"
  val keyVaultUtility = new KeyVaultUtility(new AzureKeyVaultCredential(azureClientID, dbfsPathToPEMFile), azure_key_vault_url)
  // Keep the fetched secret so it can be used; avoid printing it.
  val secretValue: String = keyVaultUtility.getValue("Enter_KEY_TO_GET_SECRET")

}