コンテンツにスキップ

gRPC-KotlinのServerInterceptorでCoroutinesを使えるようにする

問題

gRPC-KotlinはgRPC-Javaをベースに実装されていて、Kotlin Coroutinesを用いてgRPCサーバーを実装することができます。

ただし、Interceptorの実装においてはCoroutines対応が残念ながらされていません。 Interceptorを使って認証処理(I/O処理を伴う)などを実装するケースがあるため、本来はここもCoroutinesを使って書けると嬉しいです。
また、InterceptorはExecutors.newCachedThreadPoolで実行されるため、blockingな処理を実装しても問題はありませんが、可能であればCoroutinesを使って書きたいです。

この問題はこのIssueでも議論されています。
https://github.com/grpc/grpc-kotlin/issues/223

解決策

SuspendableServerInterceptor というabstract classを用意し、これを使用するようにします。

/**
 * https://stackoverflow.com/questions/53651024/grpc-java-async-call-in-serverinterceptor
 */
abstract class SuspendableServerInterceptor(
    private val context: CoroutineContext = EmptyCoroutineContext
) : ServerInterceptor {
    override fun <ReqT : Any, RespT : Any> interceptCall(
        call: ServerCall<ReqT, RespT>,
        headers: Metadata,
        next: ServerCallHandler<ReqT, RespT>
    ): ServerCall.Listener<ReqT> {
        val delayedListener = DelayedListener<ReqT>()
        delayedListener.job = CoroutineScope(
            GrpcContextElement.current()
                    + COROUTINE_CONTEXT_KEY.get()
                    + context
        ).launch {
            try {
                delayedListener.realListener = suspendableInterceptCall(call, headers, next)
                delayedListener.drainPendingCallbacks()
            } catch (e: CancellationException) {
                log.debug { "Caught CancellationException. $e" }
                call.close(Status.CANCELLED, Metadata())
            } catch (e: Exception) {
                log.error(e) { "Unhandled exception. $e" }
                call.close(Status.UNKNOWN, Metadata())
            }
        }
        return delayedListener
    }

    abstract suspend fun <ReqT : Any, RespT : Any> suspendableInterceptCall(
        call: ServerCall<ReqT, RespT>,
        headers: Metadata,
        next: ServerCallHandler<ReqT, RespT>
    ): ServerCall.Listener<ReqT>

    /**
     * ref: https://github.com/grpc/grpc-java/blob/84edc332397ed01fae2400c25196fc90d8c1a6dd/core/src/main/java/io/grpc/internal/DelayedClientCall.java#L415
     */
    private class DelayedListener<ReqT> : ServerCall.Listener<ReqT>() {
        var realListener: ServerCall.Listener<ReqT>? = null

        @Volatile
        private var passThrough = false

        @GuardedBy("this")
        private var pendingCallbacks: MutableList<Runnable> = mutableListOf()

        var job: Job? = null

        override fun onMessage(message: ReqT) {
            if (passThrough) {
                checkNotNull(realListener).onMessage(message)
            } else {
                delayOrExecute { checkNotNull(realListener).onMessage(message) }
            }
        }

        override fun onHalfClose() {
            if (passThrough) {
                checkNotNull(realListener).onHalfClose()
            } else {
                delayOrExecute { checkNotNull(realListener).onHalfClose() }
            }
        }

        override fun onCancel() {
            job?.cancel()
            if (passThrough) {
                checkNotNull(realListener).onCancel()
            } else {
                delayOrExecute { checkNotNull(realListener).onCancel() }
            }
        }

        override fun onComplete() {
            if (passThrough) {
                checkNotNull(realListener).onComplete()
            } else {
                delayOrExecute { checkNotNull(realListener).onComplete() }
            }
        }

        override fun onReady() {
            if (passThrough) {
                checkNotNull(realListener).onReady()
            } else {
                delayOrExecute { checkNotNull(realListener).onReady() }
            }
        }

        private fun delayOrExecute(runnable: Runnable) {
            synchronized(this) {
                if (!passThrough) {
                    pendingCallbacks.add(runnable)
                    return
                }
            }
            runnable.run()
        }

        fun drainPendingCallbacks() {
            check(!passThrough)
            var toRun: MutableList<Runnable> = mutableListOf()
            while (true) {
                synchronized(this) {
                    if (pendingCallbacks.isEmpty()) {
                        pendingCallbacks = mutableListOf()
                        passThrough = true
                        return
                    }
                    // Since there were pendingCallbacks, we need to process them. To maintain ordering we
                    // can't set passThrough=true until we run all pendingCallbacks, but new Runnables may be
                    // added after we drop the lock. So we will have to re-check pendingCallbacks.
                    val tmp: MutableList<Runnable> = toRun
                    toRun = pendingCallbacks
                    pendingCallbacks = tmp
                }
                for (runnable in toRun) {
                    // Avoid calling listener while lock is held to prevent deadlocks.
                    runnable.run()
                }
                toRun.clear()
            }
        }
    }

    companion object {
        private val log = KotlinLogging.logger {}

        @Suppress("UNCHECKED_CAST")
        // Get by using reflection
        internal val COROUTINE_CONTEXT_KEY: Context.Key<CoroutineContext> =
            CoroutineContextServerInterceptor::class.let { kclass ->
                val companionObject = kclass.companionObject!!
                val property = companionObject.memberProperties.single { it.name == "COROUTINE_CONTEXT_KEY" }
                checkNotNull(property.getter.call(kclass.companionObjectInstance!!)) as Context.Key<CoroutineContext>
            }
    }
}

ここで紹介されている方法を使ったものです。
grpc-java async call in ServerInterceptor

このクラスを継承し、suspendableInterceptCallをオーバーライドする形で実装すればOKです。


最終更新日: 2022/06/06 21:53