diff --git a/agent/agent.ts b/agent/agent.ts index 222cd33..c53eb69 100644 --- a/agent/agent.ts +++ b/agent/agent.ts @@ -1,30 +1,58 @@ import { END, START, StateGraph } from "@langchain/langgraph"; import { MessagesState } from "./state"; -import { toolNode } from "./nodes/tool"; +import { createToolNode } from "./nodes/tool"; import { createToolConditional } from "./conditionals/tool_end"; import { normalizationSetup } from "./nodes/normalizationSetup"; -import { dummyNormalisationModel } from "./nodes/dummyNormalisationModel"; -import { dummyTriggerEventModel } from "./nodes/dummyTriggerEventModel"; +import { arithmeticToolsByName } from "./tools/arithmetic" +import { createDummyModelNode } from "./nodes/dummyModel"; +import { verificationSetup } from "./nodes/verificationSetup"; +import { dummyRagasMetrics } from "./nodes/dummyRagasMetrics"; +import { produceRanking } from "./nodes/produceRanking"; + +const triggerEventToolNode = createToolNode(arithmeticToolsByName); +const verificationToolNode = createToolNode(arithmeticToolsByName); + +const dummyTriggerEventModel = createDummyModelNode("Trigger Events of"); +const dummyNormalisationModel = createDummyModelNode("Normalised"); +const dummyVerificationModel = createDummyModelNode("verification of"); + +const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name); +const verificationToolConditional = createToolConditional("verificationToolNode", produceRanking.name); + -const triggerEventToolConditional = createToolConditional(toolNode.name, END) const agent = new StateGraph(MessagesState) //NODES - .addNode("toolNode", toolNode) + .addNode(normalizationSetup.name, normalizationSetup) - .addNode(dummyNormalisationModel.name, dummyNormalisationModel) - .addNode(dummyTriggerEventModel.name, dummyTriggerEventModel) + .addNode("dummyNormalisationModel", dummyNormalisationModel) + + .addNode("triggerEventToolNode", triggerEventToolNode) + .addNode("dummyTriggerEventModel", dummyTriggerEventModel) + + .addNode(verificationSetup.name, verificationSetup) + .addNode("dummyVerificationModel", dummyVerificationModel) + .addNode(dummyRagasMetrics.name, dummyRagasMetrics) + .addNode("verificationToolNode", verificationToolNode) + .addNode(produceRanking.name, produceRanking) .addEdge(START, normalizationSetup.name) - .addEdge(normalizationSetup.name, dummyNormalisationModel.name) - .addEdge(dummyNormalisationModel.name, dummyTriggerEventModel.name) + .addEdge(normalizationSetup.name, "dummyNormalisationModel") + .addEdge("dummyNormalisationModel", "dummyTriggerEventModel") // @ts-expect-error - .addConditionalEdges(dummyTriggerEventModel.name, triggerEventToolConditional, [toolNode.name, END]) + .addConditionalEdges("dummyTriggerEventModel", triggerEventToolConditional, ["triggerEventToolNode", verificationSetup.name]) + .addEdge("triggerEventToolNode", "dummyTriggerEventModel") - .addEdge(toolNode.name, dummyTriggerEventModel.name) + .addEdge(verificationSetup.name, "dummyVerificationModel") + .addEdge(verificationSetup.name, dummyRagasMetrics.name) + + // @ts-expect-error + .addConditionalEdges("dummyVerificationModel", verificationToolConditional, ["verificationToolNode", produceRanking.name]) + .addEdge("verificationToolNode", "dummyVerificationModel") - .addEdge(dummyTriggerEventModel.name, END) + .addEdge(dummyRagasMetrics.name, produceRanking.name) + .compile(); export {agent} \ No newline at end of file diff --git a/agent/nodes/dummyModel.ts b/agent/nodes/dummyModel.ts new file mode 100644 index 0000000..299ecc9 --- /dev/null +++ b/agent/nodes/dummyModel.ts @@ -0,0 +1,13 @@ +import { GraphNode } from "@langchain/langgraph"; +import { MessagesState } from "../state"; +import { AIMessage } from "@langchain/core/messages"; + +export function createDummyModelNode(addition): GraphNode { + return async (state) => { + //TODO: call AI model with collected data + + return { + messages: [new AIMessage(addition + " : " + state.messages.at(-1)?.content)] + }; + }; +} \ No newline at end of file diff --git a/agent/nodes/dummyNormalisationModel.ts b/agent/nodes/dummyNormalisationModel.ts deleted file mode 100644 index 2c22305..0000000 --- a/agent/nodes/dummyNormalisationModel.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { GraphNode } from "@langchain/langgraph"; -import { MessagesState } from "../state"; -import { AIMessage, HumanMessage } from "@langchain/core/messages"; - -export const dummyNormalisationModel: GraphNode = async (state) => { - //TODO: call AI model with collected data - - return { - messages: [ new AIMessage(state.messages.at(-1)?.content + " Processed")] - }; -}; \ No newline at end of file diff --git a/agent/nodes/dummyTriggerEventModel.ts b/agent/nodes/dummyTriggerEventModel.ts deleted file mode 100644 index 8618497..0000000 --- a/agent/nodes/dummyTriggerEventModel.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { GraphNode } from "@langchain/langgraph"; -import { MessagesState } from "../state"; -import { AIMessage, HumanMessage } from "@langchain/core/messages"; - -export const dummyTriggerEventModel: GraphNode = async (state) => { - //TODO: call AI model with collected data - - return { - messages: [ new AIMessage("Trigger events of: " + state.messages.at(-1)?.content)] - }; -}; \ No newline at end of file diff --git a/agent/nodes/dummyVerificationModel.ts b/agent/nodes/dummyVerificationModel.ts deleted file mode 100644 index fcd0250..0000000 --- a/agent/nodes/dummyVerificationModel.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { GraphNode } from "@langchain/langgraph"; -import { MessagesState } from "../state"; -import { AIMessage, HumanMessage } from "@langchain/core/messages"; - -export const dummyVerificationModel: GraphNode = async (state) => { - //TODO: call AI model with collected data - - return { - messages: [ new AIMessage("Verified : " + state.messages.at(-1)?.content)] - }; -}; \ No newline at end of file diff --git a/agent/nodes/produceRanking.ts b/agent/nodes/produceRanking.ts new file mode 100644 index 0000000..5665934 --- /dev/null +++ b/agent/nodes/produceRanking.ts @@ -0,0 +1,9 @@ +import { GraphNode } from "@langchain/langgraph"; +import { MessagesState } from "../state"; +import { AIMessage, HumanMessage } from "@langchain/core/messages"; + +export const produceRanking: GraphNode = async (state) => { + //TODO: produce ranking here + + return { messages: [ new AIMessage(state.messages?.length.toString() ?? "0")] }; +}; \ No newline at end of file diff --git a/agent/nodes/tool.ts b/agent/nodes/tool.ts index 55f3c6e..e7cb4fe 100644 --- a/agent/nodes/tool.ts +++ b/agent/nodes/tool.ts @@ -1,25 +1,26 @@ import { AIMessage, ToolMessage } from "@langchain/core/messages"; import { GraphNode } from "@langchain/langgraph"; import { MessagesState } from "../state"; -import { arithmeticToolsByName } from "../tools/arithmetic"; -export const toolNode: GraphNode = async (state) => { - const lastMessage = state.messages.at(-1); +export function createToolNode(tools): GraphNode { + return async (state) => { + const lastMessage = state.messages.at(-1); - //STARTTEMP - return {messages: [new AIMessage("yeman")]} - //ENDTEMP - - if (lastMessage == null || !AIMessage.isInstance(lastMessage)) { - return { messages: [] }; - } + //STARTTEMP + return {messages: [new AIMessage("yeman")]} + //ENDTEMP - const result: ToolMessage[] = []; - for (const toolCall of lastMessage.tool_calls ?? []) { - const tool = arithmeticToolsByName[toolCall.name]; - const observation = await tool.invoke(toolCall); - result.push(observation); - } + if (lastMessage == null || !AIMessage.isInstance(lastMessage)) { + return { messages: [] }; + } - return { messages: result }; -}; \ No newline at end of file + const result: ToolMessage[] = []; + for (const toolCall of (lastMessage as AIMessage).tool_calls ?? []) { + const tool = tools[toolCall.name]; + const observation = await tool.invoke(toolCall); + result.push(observation); + } + + return { messages: result }; + }; +} \ No newline at end of file diff --git a/agent/tools/arithmetic.ts b/agent/tools/arithmetic.ts index a695c81..f7aadf5 100644 --- a/agent/tools/arithmetic.ts +++ b/agent/tools/arithmetic.ts @@ -37,4 +37,4 @@ export const arithmeticToolsByName = { [divide.name]: divide, }; -export const arithmeticTools = Object.values(arithmeticToolsByName); \ No newline at end of file +//const arithmeticTools = Object.values(arithmeticToolsByName); \ No newline at end of file