diff --git a/Source/Common/TaskGraph/imstkTaskGraph.cpp b/Source/Common/TaskGraph/imstkTaskGraph.cpp index 09768ea7f2f40bf7eafdb23f50ecc8afc6b904e6..3c73024463eab9e7a7e0648830498a101946b232 100644 --- a/Source/Common/TaskGraph/imstkTaskGraph.cpp +++ b/Source/Common/TaskGraph/imstkTaskGraph.cpp @@ -458,41 +458,44 @@ TaskGraph::resolveCriticalNodes(std::shared_ptr<TaskGraph> graph) CHECK(graph != nullptr) << "Graph is nullptr"; std::shared_ptr<TaskGraph> results = std::make_shared<TaskGraph>(*graph); - const TaskNodeAdjList& adjList = graph->getAdjList(); - const TaskNodeVector& nodes = graph->getNodes(); + const TaskNodeAdjList& adjList = results->getAdjList(); + const TaskNodeVector& nodes = results->getNodes(); // Compute the levels of each node via DFS std::unordered_map<std::shared_ptr<TaskNode>, int> depths; - { - std::unordered_set<std::shared_ptr<TaskNode>> visitedNodes; - - // DFS for the dependencies - std::stack<std::shared_ptr<TaskNode>> nodeStack; - depths[graph->getSource()] = 0; - nodeStack.push(graph->getSource()); - while (!nodeStack.empty()) + auto computeDepths = + [&]() { - std::shared_ptr<TaskNode> currNode = nodeStack.top(); - int currLevel = depths[currNode]; - nodeStack.pop(); + std::unordered_set<std::shared_ptr<TaskNode>> visitedNodes; - // Add children to stack if not yet visited - if (adjList.count(currNode) != 0) + // DFS for the dependencies + std::stack<std::shared_ptr<TaskNode>> nodeStack; + depths[results->getSource()] = 0; + nodeStack.push(results->getSource()); + while (!nodeStack.empty()) { - const TaskNodeSet& outputNodes = adjList.at(currNode); - for (TaskNodeSet::const_iterator i = outputNodes.begin(); i != outputNodes.end(); i++) + std::shared_ptr<TaskNode> currNode = nodeStack.top(); + int currLevel = depths[currNode]; + nodeStack.pop(); + + // Add children to stack if not yet visited + if (adjList.count(currNode) != 0) { - std::shared_ptr<TaskNode> childNode = *i; - if (visitedNodes.count(childNode) == 0) + const TaskNodeSet& outputNodes = adjList.at(currNode); + for (TaskNodeSet::const_iterator i = outputNodes.begin(); i != outputNodes.end(); i++) { - visitedNodes.insert(childNode); - depths[childNode] = currLevel + 1; - nodeStack.push(childNode); + std::shared_ptr<TaskNode> childNode = *i; + if (visitedNodes.count(childNode) == 0) + { + visitedNodes.insert(childNode); + depths[childNode] = currLevel + 1; + nodeStack.push(childNode); + } } } } - } - } + }; + computeDepths(); // Identify the set of critical nodes TaskNodeVector critNodes; @@ -505,43 +508,56 @@ TaskGraph::resolveCriticalNodes(std::shared_ptr<TaskGraph> graph) } // Compute the critical adjacency list + // That is, the set of critical nodes that can be reached + // from a given critical node, think of it as a subgraph TaskNodeAdjList critAdjList; - for (size_t i = 0; i < critNodes.size(); i++) - { - std::unordered_set<std::shared_ptr<TaskNode>> visitedNodes; - - // DFS for the dependencies - std::stack<std::shared_ptr<TaskNode>> nodeStack; - nodeStack.push(critNodes[i]); - while (!nodeStack.empty()) + auto computeCritList = + [&]() { - std::shared_ptr<TaskNode> currNode = nodeStack.top(); - nodeStack.pop(); + critAdjList.clear(); - // If you can reach one critical node from the other then they are adjacent - if (currNode->m_isCritical) + // For every critical node + for (size_t i = 0; i < critNodes.size(); i++) { - critAdjList[critNodes[i]].insert(currNode); - } + std::unordered_set<std::shared_ptr<TaskNode>> visitedNodes; - // Add children to stack if not yet visited - if (adjList.count(currNode) != 0) - { - const TaskNodeSet& outputNodes = adjList.at(currNode); - for (TaskNodeSet::const_iterator j = outputNodes.begin(); j != outputNodes.end(); j++) + // DFS for the dependencies (try to reach another critical) + std::stack<std::shared_ptr<TaskNode>> nodeStack; + nodeStack.push(critNodes[i]); + while (!nodeStack.empty()) { - std::shared_ptr<TaskNode> childNode = *j; - if (visitedNodes.count(childNode) == 0) + std::shared_ptr<TaskNode> currNode = nodeStack.top(); + nodeStack.pop(); + + // If you can reach one critical node from the other then they are adjacent + if (currNode->m_isCritical) { - visitedNodes.insert(childNode); - nodeStack.push(childNode); + critAdjList[critNodes[i]].insert(currNode); + } + + // Add children to stack if not yet visited + if (adjList.count(currNode) != 0) + { + const TaskNodeSet& outputNodes = adjList.at(currNode); + for (TaskNodeSet::const_iterator j = outputNodes.begin(); j != outputNodes.end(); j++) + { + std::shared_ptr<TaskNode> childNode = *j; + if (visitedNodes.count(childNode) == 0) + { + visitedNodes.insert(childNode); + nodeStack.push(childNode); + } + } } } } - } - } + }; + computeCritList(); // Now we know which critical nodes depend on each other (we are interested in those that aren't) + // Because if a critical node depends on another, then it must not be running in parallel to another + // critical node + // For every critical pair for (size_t i = 0; i < critNodes.size(); i++) { @@ -565,6 +581,8 @@ TaskGraph::resolveCriticalNodes(std::shared_ptr<TaskGraph> graph) { results->addEdge(srcNode, destNode); } + computeDepths(); + computeCritList(); } } } diff --git a/Source/Common/TaskGraph/imstkTaskGraphVizWriter.cpp b/Source/Common/TaskGraph/imstkTaskGraphVizWriter.cpp index 768b3b9508bb7eb495e0c61994f5955343ec2e43..51bbe8447ba7f4743518ff19d37d65c05c513391 100644 --- a/Source/Common/TaskGraph/imstkTaskGraphVizWriter.cpp +++ b/Source/Common/TaskGraph/imstkTaskGraphVizWriter.cpp @@ -120,7 +120,14 @@ TaskGraphVizWriter::write() } else { - file << " color=cornflowerblue"; + if (nodes[i]->m_isCritical) + { + file << " color=\"#8B2610\""; + } + else + { + file << " color=cornflowerblue"; + } } file << "];" << std::endl; } diff --git a/Source/Common/Testing/imstkTaskGraphTest.cpp b/Source/Common/Testing/imstkTaskGraphTest.cpp index ff4971e22d4ff7ca148c73ab5537428a1fe4cd9d..1970b257406978508449c69e87acd76ea40e4756 100644 --- a/Source/Common/Testing/imstkTaskGraphTest.cpp +++ b/Source/Common/Testing/imstkTaskGraphTest.cpp @@ -24,6 +24,8 @@ #include "imstkTaskGraph.h" +#include <array> + using namespace imstk; using testing::UnorderedElementsAre; using testing::ElementsAre; @@ -626,3 +628,107 @@ TEST(imstkTaskGraphTest, RemoveUnusedNodes) EXPECT_EQ(0, result->getNodes().size()); } } + +TEST(imstkTaskGraphTest, ResolveCriticalNodes0) +{ + /* + * A + * | + * B + * /|\ + * C D E (c, d, & e are critical) + * \| | + * F | + * |/ + * G + */ + + auto taskGraph = std::make_shared<TaskGraph>(); + + std::shared_ptr<TaskNode> nodeA = taskGraph->getSource(); + auto nodeB = std::make_shared<TaskNode>(); + auto nodeC = std::make_shared<TaskNode>([]() {}, "c", true); + auto nodeD = std::make_shared<TaskNode>([]() {}, "d", true); + auto nodeE = std::make_shared<TaskNode>([]() {}, "e", true); + auto nodeF = std::make_shared<TaskNode>(); + std::shared_ptr<TaskNode> nodeG = taskGraph->getSink(); + + taskGraph->addNodes({ nodeA, nodeB, nodeC, nodeD, nodeE, nodeF }); + + taskGraph->addEdge(nodeA, nodeB); + taskGraph->addEdge(nodeA, nodeE); + taskGraph->addEdge(nodeB, nodeC); + taskGraph->addEdge(nodeB, nodeD); + taskGraph->addEdge(nodeC, nodeF); + taskGraph->addEdge(nodeD, nodeF); + taskGraph->addEdge(nodeF, nodeG); + taskGraph->addEdge(nodeE, nodeG); + + taskGraph = TaskGraph::resolveCriticalNodes(taskGraph); + + std::array<std::shared_ptr<TaskNode>, 3> critNodes = { nodeC, nodeD, nodeE }; + + // Assert that C, D, & E are connected in some sort of chain + // Assert that two of the 3 nodes have critical inputs (1 should be head) + const TaskNodeAdjList& invAdjList = taskGraph->getInvAdjList(); + int critInputCount = 0; + for (int i = 0; i < 3; i++) + { + const TaskNodeSet& inputNodes = invAdjList.at(critNodes[i]); + bool critInputFound = false; + for (TaskNodeSet::const_iterator j = inputNodes.begin(); j != inputNodes.end(); j++) + { + if ((*j)->m_isCritical) + { + critInputFound = true; + } + } + critInputCount += static_cast<int>(critInputFound); + } + EXPECT_EQ(critInputCount, 2) << "Nodes C, D, & E should be connected in some sort of sequence"; +} + +TEST(imstkTaskGraphTest, ResolveCriticalNodes1) +{ + /* + * A + * / \ + * B C (b & c critical) + * \ / + * D + * / \ + * E F (e & f critical) + * \ / + * G + */ + + auto taskGraph = std::make_shared<TaskGraph>(); + + std::shared_ptr<TaskNode> nodeA = taskGraph->getSource(); + auto nodeB = std::make_shared<TaskNode>([]() {}, "b", true); + auto nodeC = std::make_shared<TaskNode>([]() {}, "c", true); + auto nodeD = std::make_shared<TaskNode>(); + auto nodeE = std::make_shared<TaskNode>([]() {}, "e", true); + auto nodeF = std::make_shared<TaskNode>([]() {}, "f", true); + std::shared_ptr<TaskNode> nodeG = taskGraph->getSink(); + + taskGraph->addNodes({ nodeA, nodeB, nodeC, nodeD, nodeE, nodeF }); + + taskGraph->addEdge(nodeA, nodeB); + taskGraph->addEdge(nodeA, nodeC); + taskGraph->addEdge(nodeB, nodeD); + taskGraph->addEdge(nodeC, nodeD); + taskGraph->addEdge(nodeD, nodeE); + taskGraph->addEdge(nodeD, nodeF); + taskGraph->addEdge(nodeE, nodeG); + taskGraph->addEdge(nodeF, nodeG); + + taskGraph = TaskGraph::resolveCriticalNodes(taskGraph); + + // There should now exist an edge between B->C, & E->F + // direction does not matter + EXPECT_TRUE(taskGraph->containsEdge(nodeB, nodeC) || taskGraph->containsEdge(nodeC, nodeB)) << + "There should exist an edge between B & C"; + EXPECT_TRUE(taskGraph->containsEdge(nodeE, nodeF) || taskGraph->containsEdge(nodeF, nodeE)) << + "There should exist an edge between E & F"; +} diff --git a/Source/Scene/imstkScene.cpp b/Source/Scene/imstkScene.cpp index e6b4e4b0292d5a865a019cb5ae2cbcb1c43be0fc..650d70d74ee1e1a23ca81214ba588996fa7e8b76 100644 --- a/Source/Scene/imstkScene.cpp +++ b/Source/Scene/imstkScene.cpp @@ -157,6 +157,11 @@ Scene::buildTaskGraph() } } + // Remove any possible unused nodes + m_taskGraph = TaskGraph::removeUnusedNodes(m_taskGraph); + // Resolve criticals across objects + m_taskGraph = TaskGraph::resolveCriticalNodes(m_taskGraph); + #ifdef IMSTK_USE_PHYSX // Gather all the physX rigid bodies std::list<std::shared_ptr<SceneObject>> rigidBodies; @@ -205,13 +210,19 @@ Scene::initTaskGraph() m_taskGraphController = std::make_shared<SequentialTaskGraphController>(); } - // Reduce the graph, removing nonfunctional nodes, and redundant edges - m_taskGraph = TaskGraph::removeUnusedNodes(m_taskGraph); if (TaskGraph::isCyclic(m_taskGraph)) { + if (m_config->writeTaskGraph) + { + TaskGraphVizWriter writer; + writer.setInput(m_taskGraph); + writer.setFileName("sceneTaskGraph.svg"); + writer.write(); + } LOG(FATAL) << "Scene TaskGraph is cyclic, cannot proceed"; return; } + // Clean up graph if user wants if (m_config->graphReductionEnabled) { m_taskGraph = TaskGraph::reduce(m_taskGraph);