In this article, we will walk through an example of accessing a secret stored in Azure Key Vault from a Databricks notebook using Scala.
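To access a Key Vault secret programmatically, we need a PEM file that contains the certificate and its private key (the code below reads both, so a certificate-only export is not enough). The PEM file can be downloaded from the Azure portal: navigate to the Key Vault service, open Certificates, and use the option to download the certificate in PEM format. This file is then used to authenticate to the key vault.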
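Along with the PEM file, we also need the Azure Key Vault URL and the client ID (application ID) of the Databricks application that authenticates to the key vault. To find the client ID, navigate to the Enterprise applications service, search for the Databricks application, and copy the Application ID from its Overview tab. That application must have the certificate from the PEM file registered as one of its credentials and must be allowed to read secrets in the key vault; otherwise the token request or the secret lookup will be rejected.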
Next, upload the downloaded PEM file to DBFS using the Databricks CLI. If you don't have the Databricks CLI installed, you can install it by following the instructions at the URL below:
https://docs.microsoft.com/en-us/azure/databricks/dev-tools/cli/
Once the Databricks CLI is installed, run the command below to configure it to connect to your workspace (and therefore to DBFS):
databricks configure --token
If you don't have a token yet, you can generate one in the Databricks workspace: navigate to User Settings and click Generate New Token. Once the Databricks CLI is configured, use the command below to upload the PEM file to DBFS:
dbfs cp -r directory_where_PEM_file_exists dbfs:/certificate_path
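Below is the Scala code, which can be run in a Databricks notebook or in any Scala project running on Databricks. It connects to Azure Key Vault and retrieves the value of a secret stored in the key vault. As the imports show, it depends on the ADAL4J, Azure Key Vault, and Bouncy Castle libraries, which must be available on the cluster. Because the PEM file is read with java.io.FileReader, the path you supply must be a local file path rather than a dbfs:/ URI. On Databricks, a file uploaded to dbfs:/certificate_path is typically reachable through the FUSE mount as /dbfs/certificate_path.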
import java.io.{File, FileReader}
import java.security.cert.X509Certificate
import java.security.{PrivateKey, Security}
import java.util.concurrent.Executors

import com.microsoft.aad.adal4j.{AsymmetricKeyCredential, AuthenticationContext}
import com.microsoft.azure.keyvault.KeyVaultClient
import com.microsoft.azure.keyvault.authentication.KeyVaultCredentials
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.openssl.{PEMKeyPair, PEMParser}
import org.bouncycastle.openssl.jcajce.{JcaPEMKeyConverter, JceOpenSSLPKCS8DecryptorProviderBuilder}
import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo

/** Reads a secret from Azure Key Vault, authenticating with a certificate + private key from a PEM file.
  *
  * The object extends `App`, so the statements at the bottom run only when `main` is invoked
  * (in a notebook: `ReadAzureKeyVaultSecretFromDataBricks.main(Array.empty)`).
  */
object ReadAzureKeyVaultSecretFromDataBricks extends App {

  /** Parses a PEM file and extracts an X.509 certificate and a private key from it.
    *
    * Objects in the file that are neither a certificate nor a supported private key are ignored.
    * If several certificates or keys are present, the last one of each kind is used.
    *
    * @param filePath local filesystem path of the PEM file (opened with `java.io.FileReader`)
    * @param password password for a PKCS#8-encrypted private key; not used for unencrypted keys
    * @return the certificate and private key found in the file
    * @throws IllegalArgumentException if the file contains no certificate or no supported private key
    */
  def readPemFile(filePath: String, password: String): X509CertificateKey = {
    // Registers Bouncy Castle under the "BC" name used below; returns -1 (no-op) if already installed.
    Security.addProvider(new BouncyCastleProvider)
    val keyConverter = new JcaPEMKeyConverter().setProvider("BC")
    val certificateConverter = new JcaX509CertificateConverter()

    val pemParser = new PEMParser(new FileReader(new File(filePath)))
    try {
      // readObject() returns null once the end of the file is reached.
      val pemObjects = Iterator.continually(pemParser.readObject()).takeWhile(_ != null).toList

      val certificate: Option[X509Certificate] = pemObjects.collect {
        case holder: X509CertificateHolder => certificateConverter.getCertificate(holder)
      }.lastOption

      val privateKey: Option[PrivateKey] = pemObjects.collect {
        case encrypted: PKCS8EncryptedPrivateKeyInfo =>
          val decryptorProvider =
            new JceOpenSSLPKCS8DecryptorProviderBuilder().build(password.toCharArray)
          keyConverter.getPrivateKey(encrypted.decryptPrivateKeyInfo(decryptorProvider))
        case info: PrivateKeyInfo =>
          keyConverter.getPrivateKey(info)
        // Traditional OpenSSL "BEGIN RSA/EC PRIVATE KEY" blocks are parsed as key pairs.
        case keyPair: PEMKeyPair =>
          keyConverter.getKeyPair(keyPair).getPrivate
      }.lastOption

      // Fail fast here instead of handing null credentials to ADAL later.
      (certificate, privateKey) match {
        case (Some(cert), Some(key)) => X509CertificateKey(cert, key)
        case (None, _) =>
          throw new IllegalArgumentException(s"No X.509 certificate found in PEM file $filePath")
        case (_, None) =>
          throw new IllegalArgumentException(s"No supported private key found in PEM file $filePath")
      }
    } finally {
      // Also closes the underlying FileReader, including when parsing or decryption throws.
      pemParser.close()
    }
  }

  /** A certificate paired with the private key used to sign with it. */
  case class X509CertificateKey(x509Certificate: X509Certificate, privateKey: PrivateKey)

  /** Key Vault credentials that obtain access tokens through ADAL with a certificate client credential.
    *
    * @param clientId    application (client) ID to authenticate as
    * @param path        local path of the PEM file holding the certificate and private key
    * @param pemPassword password of an encrypted private key; empty when the key is not encrypted
    */
  class AzureKeyVaultCredential(clientId: String, path: String, pemPassword: String = "")
      extends KeyVaultCredentials {

    // Parsed at construction, so a missing or malformed PEM file fails before any request is made.
    private val x509CertificateKey =
      ReadAzureKeyVaultSecretFromDataBricks.readPemFile(path, pemPassword)

    override def doAuthenticate(authorization: String, resource: String, scope: String): String = {
      val executor = Executors.newFixedThreadPool(1)
      try {
        val authenticationContext = new AuthenticationContext(authorization, false, executor)
        val keyCredential = AsymmetricKeyCredential.create(
          clientId,
          x509CertificateKey.privateKey,
          x509CertificateKey.x509Certificate)
        // No completion callback is passed; block on the returned Future for the result.
        authenticationContext.acquireToken(resource, keyCredential, null).get.getAccessToken
      } finally {
        // A new pool is created per token request; shut it down so its non-daemon thread is released.
        executor.shutdown()
      }
    }
  }

  /** Small wrapper around [[KeyVaultClient]] for reading secret values.
    *
    * @param keyVaultCredentials credentials the client authenticates with
    * @param defaultKeyVaultUri  vault URI used by `getValue`
    */
  class KeyVaultUtility(keyVaultCredentials: KeyVaultCredentials, defaultKeyVaultUri: String) {
    val client: KeyVaultClient = new KeyVaultClient(keyVaultCredentials)

    /** Returns the value of the secret `name` stored in the vault at `keyVaultUri`. */
    def getValueInKv(keyVaultUri: String)(name: String): String =
      client.getSecret(keyVaultUri, name).value()

    /** Returns the value of the secret `name` stored in the default vault. */
    val getValue: String => String = getValueInKv(defaultKeyVaultUri) _
  }

  // Placeholders: replace with real values before running.
  val azureClientID = "ENTER_AZURE_CLIENT_ID"
  // Must be a local path, because readPemFile opens it with java.io.FileReader (dbfs:/ URIs will not work).
  // NOTE: on Databricks, DBFS is usually reachable via the /dbfs FUSE mount, e.g. /dbfs/certificate_path/file.pem.
  val dbfsPathToPEMFile = "ENTER_DBFS_PATH_FOR_PEM_FILE"
  val azure_key_vault_url = "ENTER_KEY_VAULT_URL"

  val keyVaultUtility = new KeyVaultUtility(
    new AzureKeyVaultCredential(azureClientID, dbfsPathToPEMFile),
    azure_key_vault_url)

  // Kept in a val rather than printed, so the secret does not end up in notebook output or logs.
  val secretValue: String = keyVaultUtility.getValue("Enter_KEY_TO_GET_SECRET")
}