@@ -14,21 +14,47 @@ import com.coder.gateway.sdk.v2.models.Workspace
1414import com.coder.gateway.sdk.v2.models.WorkspaceBuild
1515import com.coder.gateway.sdk.v2.models.WorkspaceTransition
1616import com.coder.gateway.sdk.v2.models.toAgentModels
17+ import com.coder.gateway.services.CoderSettingsState
1718import com.google.gson.Gson
1819import com.google.gson.GsonBuilder
1920import com.intellij.ide.plugins.PluginManagerCore
2021import com.intellij.openapi.components.Service
22+ import com.intellij.openapi.components.service
2123import com.intellij.openapi.extensions.PluginId
2224import com.intellij.openapi.util.SystemInfo
2325import okhttp3.OkHttpClient
26+ import okhttp3.internal.tls.OkHostnameVerifier
2427import okhttp3.logging.HttpLoggingInterceptor
2528import org.zeroturnaround.exec.ProcessExecutor
2629import retrofit2.Retrofit
2730import retrofit2.converter.gson.GsonConverterFactory
31+ import java.io.File
32+ import java.io.FileInputStream
2833import java.net.HttpURLConnection.HTTP_CREATED
34+ import java.net.InetAddress
35+ import java.net.Socket
2936import java.net.URL
37+ import java.security.KeyFactory
38+ import java.security.KeyStore
39+ import java.security.PrivateKey
40+ import java.security.cert.CertificateFactory
41+ import java.security.cert.X509Certificate
42+ import java.security.spec.InvalidKeySpecException
43+ import java.security.spec.PKCS8EncodedKeySpec
3044import java.time.Instant
45+ import java.util.Base64
46+ import java.util.Locale
3147import java.util.UUID
48+ import javax.net.ssl.HostnameVerifier
49+ import javax.net.ssl.KeyManagerFactory
50+ import javax.net.ssl.SNIHostName
51+ import javax.net.ssl.SSLContext
52+ import javax.net.ssl.SSLSession
53+ import javax.net.ssl.SSLSocket
54+ import javax.net.ssl.SSLSocketFactory
55+ import javax.net.ssl.TrustManagerFactory
56+ import javax.net.ssl.TrustManager
57+ import javax.net.ssl.X509TrustManager
3258
3359@Service(Service .Level .APP )
3460class CoderRestClientService {
@@ -66,7 +92,11 @@ class CoderRestClient(var url: URL, var token: String,
6692 pluginVersion = PluginManagerCore .getPlugin(PluginId .getId(" com.coder.gateway" ))!! .version // this is the id from the plugin.xml
6793 }
6894
95+ val socketFactory = coderSocketFactory()
96+ val trustManagers = coderTrustManagers()
6997 httpClient = OkHttpClient .Builder ()
98+ .sslSocketFactory(socketFactory, trustManagers[0 ] as X509TrustManager )
99+ .hostnameVerifier(CoderHostnameVerifier ())
70100 .addInterceptor { it.proceed(it.request().newBuilder().addHeader(" Coder-Session-Token" , token).build()) }
71101 .addInterceptor { it.proceed(it.request().newBuilder().addHeader(" User-Agent" , " Coder Gateway/${pluginVersion} (${SystemInfo .getOsNameAndVersion()} ; ${SystemInfo .OS_ARCH } )" ).build()) }
72102 .addInterceptor {
@@ -218,3 +248,168 @@ class CoderRestClient(var url: URL, var token: String,
218248 }
219249 }
220250}
251+
252+ fun coderSocketFactory () : SSLSocketFactory {
253+ val state: CoderSettingsState = service()
254+
255+ if (state.tlsCertPath.isBlank() || state.tlsKeyPath.isBlank()) {
256+ return SSLSocketFactory .getDefault() as SSLSocketFactory
257+ }
258+
259+ val certificateFactory = CertificateFactory .getInstance(" X.509" )
260+ val certInputStream = FileInputStream (state.tlsCertPath)
261+ val certChain = certificateFactory.generateCertificates(certInputStream)
262+ certInputStream.close()
263+
264+ // ideally we would use something like PemReader from BouncyCastle, but
265+ // BC is used by the IDE. This makes using BC very impractical since
266+ // type casting will mismatch due to the different class loaders.
267+ val privateKeyPem = File (state.tlsKeyPath).readText()
268+ val start: Int = privateKeyPem.indexOf(" -----BEGIN PRIVATE KEY-----" )
269+ val end: Int = privateKeyPem.indexOf(" -----END PRIVATE KEY-----" , start)
270+ val pemBytes: ByteArray = Base64 .getDecoder().decode(
271+ privateKeyPem.substring(start + " -----BEGIN PRIVATE KEY-----" .length, end)
272+ .replace(" \\ s+" .toRegex(), " " )
273+ )
274+
275+ var privateKey : PrivateKey
276+ try {
277+ val kf = KeyFactory .getInstance(" RSA" )
278+ val keySpec = PKCS8EncodedKeySpec (pemBytes)
279+ privateKey = kf.generatePrivate(keySpec)
280+ } catch (e: InvalidKeySpecException ) {
281+ val kf = KeyFactory .getInstance(" EC" )
282+ val keySpec = PKCS8EncodedKeySpec (pemBytes)
283+ privateKey = kf.generatePrivate(keySpec)
284+ }
285+
286+ val keyStore = KeyStore .getInstance(KeyStore .getDefaultType())
287+ keyStore.load(null )
288+ certChain.withIndex().forEach {
289+ keyStore.setCertificateEntry(" cert${it.index} " , it.value as X509Certificate )
290+ }
291+ keyStore.setKeyEntry(" key" , privateKey, null , certChain.toTypedArray())
292+
293+ val keyManagerFactory = KeyManagerFactory .getInstance(KeyManagerFactory .getDefaultAlgorithm())
294+ keyManagerFactory.init (keyStore, null )
295+
296+ val sslContext = SSLContext .getInstance(" TLS" )
297+
298+ val trustManagers = coderTrustManagers()
299+ sslContext.init (keyManagerFactory.keyManagers, trustManagers, null )
300+
301+ if (state.tlsAlternateHostname.isBlank()) {
302+ return sslContext.socketFactory
303+ }
304+
305+ return AlternateNameSSLSocketFactory (sslContext.socketFactory, state.tlsAlternateHostname)
306+ }
307+
308+ fun coderTrustManagers () : Array <TrustManager > {
309+ val state: CoderSettingsState = service()
310+
311+ val trustManagerFactory = TrustManagerFactory .getInstance(TrustManagerFactory .getDefaultAlgorithm())
312+ if (state.tlsCAPath.isBlank()) {
313+ // return default trust managers
314+ trustManagerFactory.init (null as KeyStore ? )
315+ return trustManagerFactory.trustManagers
316+ }
317+
318+
319+ val certificateFactory = CertificateFactory .getInstance(" X.509" )
320+ val caInputStream = FileInputStream (state.tlsCAPath)
321+ val certChain = certificateFactory.generateCertificates(caInputStream)
322+
323+ val truststore = KeyStore .getInstance(KeyStore .getDefaultType())
324+ truststore.load(null )
325+ certChain.withIndex().forEach {
326+ truststore.setCertificateEntry(" cert${it.index} " , it.value as X509Certificate )
327+ }
328+ trustManagerFactory.init (truststore)
329+ return trustManagerFactory.trustManagers
330+ }
331+
332+ class AlternateNameSSLSocketFactory (private val delegate : SSLSocketFactory , private val alternateName : String ) : SSLSocketFactory() {
333+ override fun getDefaultCipherSuites (): Array <String > {
334+ return delegate.defaultCipherSuites
335+ }
336+
337+ override fun getSupportedCipherSuites (): Array <String > {
338+ return delegate.supportedCipherSuites
339+ }
340+
341+ override fun createSocket (): Socket {
342+ val socket = delegate.createSocket() as SSLSocket
343+ customizeSocket(socket)
344+ return socket
345+ }
346+
347+ override fun createSocket (host : String? , port : Int ): Socket {
348+ val socket = delegate.createSocket(host, port) as SSLSocket
349+ customizeSocket(socket)
350+ return socket
351+ }
352+
353+ override fun createSocket (host : String? , port : Int , localHost : InetAddress ? , localPort : Int ): Socket {
354+ val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket
355+ customizeSocket(socket)
356+ return socket
357+ }
358+
359+ override fun createSocket (host : InetAddress ? , port : Int ): Socket {
360+ val socket = delegate.createSocket(host, port) as SSLSocket
361+ customizeSocket(socket)
362+ return socket
363+ }
364+
365+ override fun createSocket (address : InetAddress ? , port : Int , localAddress : InetAddress ? , localPort : Int ): Socket {
366+ val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket
367+ customizeSocket(socket)
368+ return socket
369+ }
370+
371+ override fun createSocket (s : Socket ? , host : String? , port : Int , autoClose : Boolean ): Socket {
372+ val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket
373+ customizeSocket(socket)
374+ return socket
375+ }
376+
377+ private fun customizeSocket (socket : SSLSocket ) {
378+ val params = socket.sslParameters
379+ params.serverNames = listOf (SNIHostName (alternateName))
380+ socket.sslParameters = params
381+ }
382+ }
383+
384+ class CoderHostnameVerifier () : HostnameVerifier {
385+ private val alternateName: String
386+
387+ init {
388+ val state: CoderSettingsState = service()
389+ this .alternateName = state.tlsAlternateHostname.lowercase(Locale .getDefault())
390+ }
391+
392+ override fun verify (host : String , session : SSLSession ): Boolean {
393+ if (alternateName.isEmpty()) {
394+ return OkHostnameVerifier .verify(host, session)
395+ }
396+ val certs = session.peerCertificates ? : return false
397+ for (cert in certs) {
398+ if (cert !is X509Certificate ) {
399+ continue
400+ }
401+ val entries = cert.subjectAlternativeNames ? : continue
402+ for (entry in entries) {
403+ val kind = entry[0 ] as Int
404+ if (kind != 2 ) { // DNS Name
405+ continue
406+ }
407+ val hostname = entry[1 ] as String
408+ if (hostname.lowercase(Locale .getDefault()) == alternateName) {
409+ return true
410+ }
411+ }
412+ }
413+ return false
414+ }
415+ }
0 commit comments