@@ -14,21 +14,48 @@ import com.coder.gateway.sdk.v2.models.Workspace
1414import com.coder.gateway.sdk.v2.models.WorkspaceBuild
1515import com.coder.gateway.sdk.v2.models.WorkspaceTransition
1616import com.coder.gateway.sdk.v2.models.toAgentModels
17+ import com.coder.gateway.services.CoderSettingsState
1718import com.google.gson.Gson
1819import com.google.gson.GsonBuilder
1920import com.intellij.ide.plugins.PluginManagerCore
2021import com.intellij.openapi.components.Service
2122import com.intellij.openapi.extensions.PluginId
2223import com.intellij.openapi.util.SystemInfo
2324import okhttp3.OkHttpClient
25+ import okhttp3.internal.tls.OkHostnameVerifier
2426import okhttp3.logging.HttpLoggingInterceptor
2527import org.zeroturnaround.exec.ProcessExecutor
2628import retrofit2.Retrofit
2729import retrofit2.converter.gson.GsonConverterFactory
30+ import java.io.File
31+ import java.io.FileInputStream
2832import java.net.HttpURLConnection.HTTP_CREATED
33+ import java.net.InetAddress
34+ import java.net.Socket
2935import java.net.URL
36+ import java.nio.file.Path
37+ import java.security.KeyFactory
38+ import java.security.KeyStore
39+ import java.security.PrivateKey
40+ import java.security.cert.CertificateException
41+ import java.security.cert.CertificateFactory
42+ import java.security.cert.X509Certificate
43+ import java.security.spec.InvalidKeySpecException
44+ import java.security.spec.PKCS8EncodedKeySpec
3045import java.time.Instant
46+ import java.util.Base64
47+ import java.util.Locale
3148import java.util.UUID
49+ import javax.net.ssl.HostnameVerifier
50+ import javax.net.ssl.KeyManagerFactory
51+ import javax.net.ssl.SNIHostName
52+ import javax.net.ssl.SSLContext
53+ import javax.net.ssl.SSLSession
54+ import javax.net.ssl.SSLSocket
55+ import javax.net.ssl.SSLSocketFactory
56+ import javax.net.ssl.TrustManagerFactory
57+ import javax.net.ssl.TrustManager
58+ import javax.net.ssl.X509TrustManager
3259
3360@Service(Service .Level .APP )
3461class CoderRestClientService {
@@ -44,18 +71,19 @@ class CoderRestClientService {
4471 *
4572 * @throws [AuthenticationResponseException] if authentication failed.
4673 */
47- fun initClientSession (url : URL , token : String , headerCommand : String? ): User {
48- client = CoderRestClient (url, token, headerCommand, null )
74+ fun initClientSession (url : URL , token : String , settings : CoderSettingsState ): User {
75+ client = CoderRestClient (url, token, null , settings )
4976 me = client.me()
5077 buildVersion = client.buildInfo().version
5178 isReady = true
5279 return me
5380 }
5481}
5582
56- class CoderRestClient (var url : URL , var token : String ,
57- private var headerCommand : String? ,
83+ class CoderRestClient (
84+ var url : URL , var token : String ,
5885 private var pluginVersion : String? ,
86+ private var settings : CoderSettingsState ,
5987) {
6088 private var httpClient: OkHttpClient
6189 private var retroRestClient: CoderV2RestFacade
@@ -66,12 +94,16 @@ class CoderRestClient(var url: URL, var token: String,
6694 pluginVersion = PluginManagerCore .getPlugin(PluginId .getId(" com.coder.gateway" ))!! .version // this is the id from the plugin.xml
6795 }
6896
97+ val socketFactory = coderSocketFactory(settings)
98+ val trustManagers = coderTrustManagers(settings.tlsCAPath)
6999 httpClient = OkHttpClient .Builder ()
100+ .sslSocketFactory(socketFactory, trustManagers[0 ] as X509TrustManager )
101+ .hostnameVerifier(CoderHostnameVerifier (settings.tlsAlternateHostname))
70102 .addInterceptor { it.proceed(it.request().newBuilder().addHeader(" Coder-Session-Token" , token).build()) }
71103 .addInterceptor { it.proceed(it.request().newBuilder().addHeader(" User-Agent" , " Coder Gateway/${pluginVersion} (${SystemInfo .getOsNameAndVersion()} ; ${SystemInfo .OS_ARCH } )" ).build()) }
72104 .addInterceptor {
73105 var request = it.request()
74- val headers = getHeaders(url, headerCommand)
106+ val headers = getHeaders(url, settings. headerCommand)
75107 if (headers.size > 0 ) {
76108 val builder = request.newBuilder()
77109 headers.forEach { h -> builder.addHeader(h.key, h.value) }
@@ -218,3 +250,203 @@ class CoderRestClient(var url: URL, var token: String,
218250 }
219251 }
220252}
253+
254+ fun coderSocketFactory (settings : CoderSettingsState ) : SSLSocketFactory {
255+ if (settings.tlsCertPath.isBlank() || settings.tlsKeyPath.isBlank()) {
256+ return SSLSocketFactory .getDefault() as SSLSocketFactory
257+ }
258+
259+ val certificateFactory = CertificateFactory .getInstance(" X.509" )
260+ val certInputStream = FileInputStream (expandPath(settings.tlsCertPath))
261+ val certChain = certificateFactory.generateCertificates(certInputStream)
262+ certInputStream.close()
263+
264+ // ideally we would use something like PemReader from BouncyCastle, but
265+ // BC is used by the IDE. This makes using BC very impractical since
266+ // type casting will mismatch due to the different class loaders.
267+ val privateKeyPem = File (expandPath(settings.tlsKeyPath)).readText()
268+ val start: Int = privateKeyPem.indexOf(" -----BEGIN PRIVATE KEY-----" )
269+ val end: Int = privateKeyPem.indexOf(" -----END PRIVATE KEY-----" , start)
270+ val pemBytes: ByteArray = Base64 .getDecoder().decode(
271+ privateKeyPem.substring(start + " -----BEGIN PRIVATE KEY-----" .length, end)
272+ .replace(" \\ s+" .toRegex(), " " )
273+ )
274+
275+ var privateKey : PrivateKey
276+ try {
277+ val kf = KeyFactory .getInstance(" RSA" )
278+ val keySpec = PKCS8EncodedKeySpec (pemBytes)
279+ privateKey = kf.generatePrivate(keySpec)
280+ } catch (e: InvalidKeySpecException ) {
281+ val kf = KeyFactory .getInstance(" EC" )
282+ val keySpec = PKCS8EncodedKeySpec (pemBytes)
283+ privateKey = kf.generatePrivate(keySpec)
284+ }
285+
286+ val keyStore = KeyStore .getInstance(KeyStore .getDefaultType())
287+ keyStore.load(null )
288+ certChain.withIndex().forEach {
289+ keyStore.setCertificateEntry(" cert${it.index} " , it.value as X509Certificate )
290+ }
291+ keyStore.setKeyEntry(" key" , privateKey, null , certChain.toTypedArray())
292+
293+ val keyManagerFactory = KeyManagerFactory .getInstance(KeyManagerFactory .getDefaultAlgorithm())
294+ keyManagerFactory.init (keyStore, null )
295+
296+ val sslContext = SSLContext .getInstance(" TLS" )
297+
298+ val trustManagers = coderTrustManagers(settings.tlsCAPath)
299+ sslContext.init (keyManagerFactory.keyManagers, trustManagers, null )
300+
301+ if (settings.tlsAlternateHostname.isBlank()) {
302+ return sslContext.socketFactory
303+ }
304+
305+ return AlternateNameSSLSocketFactory (sslContext.socketFactory, settings.tlsAlternateHostname)
306+ }
307+
308+ fun coderTrustManagers (tlsCAPath : String ) : Array <TrustManager > {
309+ val trustManagerFactory = TrustManagerFactory .getInstance(TrustManagerFactory .getDefaultAlgorithm())
310+ if (tlsCAPath.isBlank()) {
311+ // return default trust managers
312+ trustManagerFactory.init (null as KeyStore ? )
313+ return trustManagerFactory.trustManagers
314+ }
315+
316+
317+ val certificateFactory = CertificateFactory .getInstance(" X.509" )
318+ val caInputStream = FileInputStream (expandPath(tlsCAPath))
319+ val certChain = certificateFactory.generateCertificates(caInputStream)
320+
321+ val truststore = KeyStore .getInstance(KeyStore .getDefaultType())
322+ truststore.load(null )
323+ certChain.withIndex().forEach {
324+ truststore.setCertificateEntry(" cert${it.index} " , it.value as X509Certificate )
325+ }
326+ trustManagerFactory.init (truststore)
327+ return trustManagerFactory.trustManagers.map { MergedSystemTrustManger (it as X509TrustManager ) }.toTypedArray()
328+ }
329+
330+ fun expandPath (path : String ): String {
331+ if (path.startsWith(" ~/" )) {
332+ return Path .of(System .getProperty(" user.home" ), path.substring(1 )).toString()
333+ }
334+ if (path.startsWith(" \$ HOME/" )) {
335+ return Path .of(System .getProperty(" user.home" ), path.substring(5 )).toString()
336+ }
337+ if (path.startsWith(" \$ {user.home}/" )) {
338+ return Path .of(System .getProperty(" user.home" ), path.substring(12 )).toString()
339+ }
340+ return path
341+ }
342+
343+ class AlternateNameSSLSocketFactory (private val delegate : SSLSocketFactory , private val alternateName : String ) : SSLSocketFactory() {
344+ override fun getDefaultCipherSuites (): Array <String > {
345+ return delegate.defaultCipherSuites
346+ }
347+
348+ override fun getSupportedCipherSuites (): Array <String > {
349+ return delegate.supportedCipherSuites
350+ }
351+
352+ override fun createSocket (): Socket {
353+ val socket = delegate.createSocket() as SSLSocket
354+ customizeSocket(socket)
355+ return socket
356+ }
357+
358+ override fun createSocket (host : String? , port : Int ): Socket {
359+ val socket = delegate.createSocket(host, port) as SSLSocket
360+ customizeSocket(socket)
361+ return socket
362+ }
363+
364+ override fun createSocket (host : String? , port : Int , localHost : InetAddress ? , localPort : Int ): Socket {
365+ val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket
366+ customizeSocket(socket)
367+ return socket
368+ }
369+
370+ override fun createSocket (host : InetAddress ? , port : Int ): Socket {
371+ val socket = delegate.createSocket(host, port) as SSLSocket
372+ customizeSocket(socket)
373+ return socket
374+ }
375+
376+ override fun createSocket (address : InetAddress ? , port : Int , localAddress : InetAddress ? , localPort : Int ): Socket {
377+ val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket
378+ customizeSocket(socket)
379+ return socket
380+ }
381+
382+ override fun createSocket (s : Socket ? , host : String? , port : Int , autoClose : Boolean ): Socket {
383+ val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket
384+ customizeSocket(socket)
385+ return socket
386+ }
387+
388+ private fun customizeSocket (socket : SSLSocket ) {
389+ val params = socket.sslParameters
390+ params.serverNames = listOf (SNIHostName (alternateName))
391+ socket.sslParameters = params
392+ }
393+ }
394+
395+ class CoderHostnameVerifier (private val alternateName : String ) : HostnameVerifier {
396+ override fun verify (host : String , session : SSLSession ): Boolean {
397+ if (alternateName.isEmpty()) {
398+ println (" using default hostname verifier, alternateName is empty" )
399+ return OkHostnameVerifier .verify(host, session)
400+ }
401+ println (" Looking for alternate hostname: $alternateName " )
402+ val certs = session.peerCertificates ? : return false
403+ for (cert in certs) {
404+ if (cert !is X509Certificate ) {
405+ continue
406+ }
407+ val entries = cert.subjectAlternativeNames ? : continue
408+ for (entry in entries) {
409+ val kind = entry[0 ] as Int
410+ if (kind != 2 ) { // DNS Name
411+ continue
412+ }
413+ val hostname = entry[1 ] as String
414+ println (" Found cert hostname: $hostname " )
415+ if (hostname.lowercase(Locale .getDefault()) == alternateName) {
416+ return true
417+ }
418+ }
419+ }
420+ println (" No matching hostname found" )
421+ return false
422+ }
423+ }
424+
425+ class MergedSystemTrustManger (private val otherTrustManager : X509TrustManager ) : X509TrustManager {
426+ private val systemTrustManager : X509TrustManager
427+ init {
428+ val trustManagerFactory = TrustManagerFactory .getInstance(TrustManagerFactory .getDefaultAlgorithm())
429+ trustManagerFactory.init (null as KeyStore ? )
430+ systemTrustManager = trustManagerFactory.trustManagers.first { it is X509TrustManager } as X509TrustManager
431+ }
432+
433+ override fun checkClientTrusted (chain : Array <out X509Certificate >, authType : String? ) {
434+ try {
435+ otherTrustManager.checkClientTrusted(chain, authType)
436+ } catch (e: CertificateException ) {
437+ systemTrustManager.checkClientTrusted(chain, authType)
438+ }
439+ }
440+
441+ override fun checkServerTrusted (chain : Array <out X509Certificate >, authType : String? ) {
442+ try {
443+ otherTrustManager.checkServerTrusted(chain, authType)
444+ } catch (e: CertificateException ) {
445+ systemTrustManager.checkServerTrusted(chain, authType)
446+ }
447+ }
448+
449+ override fun getAcceptedIssuers (): Array <X509Certificate > {
450+ return otherTrustManager.acceptedIssuers + systemTrustManager.acceptedIssuers
451+ }
452+ }
0 commit comments