From 247723cee4dc7cf5a20a038a5a9ee9b802c5dd6f Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Thu, 5 Dec 2024 09:58:05 +0800 Subject: [PATCH] fixup --- .../org/apache/gluten/backend/Component.scala | 14 ++++++++------ .../scala/org/apache/gluten/backend/package.scala | 3 ++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/backend/Component.scala b/gluten-core/src/main/scala/org/apache/gluten/backend/Component.scala index 4bad6bab5e06..3f35e44d385d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backend/Component.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backend/Component.scala @@ -81,6 +81,10 @@ object Component { graph.sorted() } + private[backend] def sortedUnsafe(): Seq[Component] = { + graph.sorted() + } + private class Registry { private val lookupByUid: mutable.Map[Int, Component] = mutable.Map() private val lookupByClass: mutable.Map[Class[_ <: Component], Component] = mutable.Map() @@ -119,14 +123,14 @@ object Component { } } - class Graph private[Component] { + private class Graph { import Graph._ private val registry: Registry = new Registry() private val requirements: mutable.Buffer[(Int, Class[_ <: Component])] = mutable.Buffer() private var sortedComponents: Option[Seq[Component]] = None - private[Component] def add(comp: Component): Unit = synchronized { + def add(comp: Component): Unit = synchronized { require( !registry.isUidRegistered(comp.uid), s"Component UID ${comp.uid} already registered: ${comp.name()}") @@ -137,9 +141,7 @@ object Component { sortedComponents = None } - private[Component] def declareRequirement( - comp: Component, - requiredCompClass: Class[_ <: Component]): Unit = + def declareRequirement(comp: Component, requiredCompClass: Class[_ <: Component]): Unit = synchronized { require(registry.isUidRegistered(comp.uid)) require(registry.isClassRegistered(comp.getClass)) @@ -191,7 +193,7 @@ object Component { * requirement from component A to component B. */ // format: on - private[Component] def sorted(): Seq[Component] = synchronized { + def sorted(): Seq[Component] = synchronized { if (sortedComponents.isDefined) { return sortedComponents.get } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backend/package.scala b/gluten-core/src/main/scala/org/apache/gluten/backend/package.scala index 5343ffe922bd..a9981719a333 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backend/package.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backend/package.scala @@ -40,7 +40,8 @@ package object backend extends Logging { all.foreach(_.ensureRegistered()) // Output log so user could view the component loading order. - val components = Component.sorted() + // Call #sortedUnsafe than on #sorted to avoid unnecessary recursion. + val components = Component.sortedUnsafe() logInfo(s"Components registered within order: ${components.map(_.name()).mkString(", ")}") } }