feat: 支持图谱向量化

This commit is contained in:
haokai
2025-04-03 17:49:55 +08:00
parent 6be76ebc33
commit f8b5c43166
47 changed files with 2294 additions and 128 deletions

View File

@@ -0,0 +1,50 @@
package com.shuwen.data.entity.manage.api.model.chat;
import com.alibaba.fastjson.annotation.JSONField;
import lombok.Getter;
import lombok.Setter;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 11:57
*/
@Getter
@Setter
public class EdgeVectorInfo extends GraphVectorInfo {
/**
* 边ID
*/
@JSONField(name = "edge_id")
private String edgeId;
/**
* 关系名
*/
@JSONField(name = "rel_name")
private String relName;
/**
* 出实体类型
*/
@JSONField(name = "label_out")
private String labelOut;
/**
* 出实体ID
*/
@JSONField(name = "entity_id_out")
private String entityIdOut;
/**
* 入实体类型
*/
@JSONField(name = "label_in")
private String labelIn;
/**
* 入实体ID
*/
@JSONField(name = "entity_id_in")
private String entityIdIn;
public EdgeVectorInfo() {
super(VectorDataType.EDGE);
}
}

View File

@@ -0,0 +1,44 @@
package com.shuwen.data.entity.manage.api.model.chat;
import com.alibaba.fastjson.annotation.JSONField;
import lombok.Getter;
import lombok.Setter;
import java.math.BigDecimal;
import java.util.List;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 11:55
*/
@Getter
@Setter
public abstract class GraphVectorInfo {
/**
* 文本片段
*/
private String content;
/**
* 文本向量
*/
private List<BigDecimal> feature;
/**
* 图谱名称
*/
private String graph;
/**
* 向量类型
*/
@JSONField(name = "vector_data_type")
private VectorDataType vectorDataType;
public GraphVectorInfo(VectorDataType vectorDataType) {
this.vectorDataType = vectorDataType;
}
public enum VectorDataType {
VERTEX, EDGE
}
}

View File

@@ -1,9 +1,11 @@
package com.shuwen.data.entity.manage.api.model.chat;
import com.google.common.collect.Sets;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
import java.util.Set;
/**
* Project: entity-manage
@@ -46,4 +48,15 @@ public class RelationPath {
* 路径信息
*/
private List<EdgePathElement> relationList;
/**
* 边ID集合
*/
private Set<String> edgeIds;
public void addEdgeId(String edgeId) {
if (edgeIds == null) {
edgeIds = Sets.newHashSet();
}
edgeIds.add(edgeId);
}
}

View File

@@ -27,6 +27,10 @@ public class RetrievalInfo {
private Map<String, JSONObject> fullTextRecallEntityMap;
private Map<String, Map<String, Object>> vectorRecallEntityMap;
private List<RelationPath> vectorRecallRelationPathList;
private String subjectId;
//主语信息,只有关系类问题才会填充

View File

@@ -0,0 +1,62 @@
package com.shuwen.data.entity.manage.api.model.chat;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Maps;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
import java.util.Map;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/4/2 15:04
*/
@Getter
@Setter
public class VectorRetrievalInfo {
/**
* 向量召回实体结果
*/
private Map<String, Map<String, Object>> vectorRecallEntityMap;
/**
* 向量召回关系结果
*/
private Map<String, RelationPath> vectorRecallRelationPathMap;
/**
* 排序语料
*/
private List<RankData> rankDataList;
/**
* 排序分数
*/
private List<Float> rankScoreList;
/**
* 调用消耗
*/
private JSONObject usage = new JSONObject();
public void putUsage(String key, long used) {
usage.put(key, used);
}
public void addEntity(String entityId, Map<String, Object> entity) {
if (vectorRecallEntityMap == null) {
vectorRecallEntityMap = Maps.newLinkedHashMap();
}
if (!vectorRecallEntityMap.containsKey(entityId)) {
vectorRecallEntityMap.put(entityId, entity);
}
}
public void addRelationPath(String edgeId, RelationPath path) {
if (vectorRecallRelationPathMap == null) {
vectorRecallRelationPathMap = Maps.newLinkedHashMap();
}
if (!vectorRecallRelationPathMap.containsKey(edgeId)) {
vectorRecallRelationPathMap.put(edgeId, path);
}
}
}

View File

@@ -0,0 +1,29 @@
package com.shuwen.data.entity.manage.api.model.chat;
import com.alibaba.fastjson.annotation.JSONField;
import lombok.Getter;
import lombok.Setter;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 11:57
*/
@Getter
@Setter
public class VertexVectorInfo extends GraphVectorInfo {
/**
* 实体类型
*/
private String label;
/**
* 实体ID
*/
@JSONField(name = "entity_id")
private String entityId;
public VertexVectorInfo() {
super(VectorDataType.VERTEX);
}
}

View File

@@ -32,6 +32,7 @@ public enum ReturnCodeEnum {
ALGO_ERROR(2003, "algo_error:fusion"),
ONTOLOGY_ERROR(2004, "ontology_error"),
PARTIAL_FAIL(2005, "partial_fail"),
HTTP_ERROR(2006, "http_error"),
//3*实体错误
ENTITY_CONFLICT(3001, "entity_conflict"),

View File

@@ -14,7 +14,7 @@ import com.shuwen.data.entity.manage.api.model.enums.ReturnCodeEnum;
* @author yangcongyu@shuwen.com
* @version 1.0
*/
public class AlgoServiceException extends Exception {
public class AlgoServiceException extends RuntimeException {
private ReturnCodeEnum code = ReturnCodeEnum.ALGO_ERROR;
public AlgoServiceException() {

View File

@@ -0,0 +1,30 @@
package com.shuwen.data.entity.manage.api.model.exception;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/19 11:51
*/
public class QueueServiceException extends RuntimeException {
private QueueTypeEnum type;
public QueueServiceException(QueueTypeEnum type, String msg) {
super(msg);
this.type = type;
}
public QueueServiceException(QueueTypeEnum type, String msg, Throwable e) {
super(msg, e);
this.type = type;
}
public QueueTypeEnum getType() {
return type;
}
public enum QueueTypeEnum {
MQ, REDIS
}
}

View File

@@ -0,0 +1,20 @@
package com.shuwen.data.entity.manage.api.model.graph.dto;
import com.shuwen.data.entity.manage.api.model.base.AbstractRequest;
import lombok.Getter;
import lombok.Setter;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/4/1 10:59
*/
@Getter
@Setter
public class EdgeVectorCrudDto extends AbstractRequest {
/**
* 边ID
*/
private String edgeId;
}

View File

@@ -0,0 +1,60 @@
package com.shuwen.data.entity.manage.api.model.graph.dto;
import com.shuwen.data.entity.manage.api.model.base.AbstractRequest;
import lombok.Getter;
import lombok.Setter;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/4/2 15:39
*/
@Getter
@Setter
public class VectorSearchDto extends AbstractRequest {
/**
* 查询语句
*/
private String query;
/**
* 实体类型
*/
private String label;
/**
* 实体ID
*/
private String entityId;
/**
* 边ID
*/
private String edgeId;
/**
* 关系名
*/
private String relName;
/**
* 出实体类型
*/
private String labelOut;
/**
* 出实体ID
*/
private String entityIdOut;
/**
* 入实体类型
*/
private String labelIn;
/**
* 入实体ID
*/
private String entityIdIn;
/**
* 图谱名称
*/
private String graph;
/**
* 返回数量
*/
private int size;
}

View File

@@ -0,0 +1,24 @@
package com.shuwen.data.entity.manage.api.model.graph.dto;
import com.shuwen.data.entity.manage.api.model.base.AbstractRequest;
import lombok.Getter;
import lombok.Setter;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 17:17
*/
@Getter
@Setter
public class VertexVectorCrudDto extends AbstractRequest {
/**
* 实体类型
*/
private String label;
/**
* 实体ID
*/
private String entityId;
}

View File

@@ -0,0 +1,34 @@
package com.shuwen.data.entity.manage.api.model.graph.result;
import com.shuwen.data.entity.manage.api.model.chat.EdgeVectorInfo;
import lombok.Getter;
import lombok.Setter;
import java.util.Map;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/4/3 14:50
*/
@Getter
@Setter
public class EdgeVectorResult extends EdgeVectorInfo {
/**
* 分数
*/
private double score;
/**
* 边属性
*/
private Map<String, Object> edge;
/**
* source节点属性
*/
private Map<String, Object> source;
/**
* target节点属性
*/
private Map<String, Object> target;
}

View File

@@ -0,0 +1,28 @@
package com.shuwen.data.entity.manage.api.model.graph.result;
import com.alibaba.fastjson.annotation.JSONField;
import com.shuwen.data.entity.manage.api.model.chat.VertexVectorInfo;
import lombok.Getter;
import lombok.Setter;
import java.util.Map;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/4/3 14:49
*/
@Getter
@Setter
public class VertexVectorResult extends VertexVectorInfo {
/**
* 分数
*/
@JSONField(name = "_score")
private double score;
/**
* 实体详情
*/
private Map<String, Object> entity;
}

View File

@@ -0,0 +1,59 @@
package com.shuwen.data.entity.manage.api.service;
import com.shuwen.data.entity.manage.api.model.chat.GraphVectorInfo;
import com.shuwen.data.entity.manage.api.model.graph.dto.EdgeVectorCrudDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VectorSearchDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VertexVectorCrudDto;
import java.util.List;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/19 15:05
*/
public interface IGraphVectorService {
/**
* 构建Vertex向量
* @param dto 请求
*/
void create(VertexVectorCrudDto dto);
/**
* 删除Vertex向量
* @param dto 请求
*/
void delete(VertexVectorCrudDto dto);
/**
* 重建Vertex向量
* @param dto 请求
*/
void rebuild(VertexVectorCrudDto dto);
/**
* 构建Edge向量
* @param dto 请求
*/
void create(EdgeVectorCrudDto dto);
/**
* 删除Edge向量
* @param dto 请求
*/
void delete(EdgeVectorCrudDto dto);
/**
* 重建Edge向量
* @param dto 请求
*/
void rebuild(EdgeVectorCrudDto dto);
/**
* 向量搜索
* @param dto 请求
* @return 结果
*/
List<GraphVectorInfo> search(VectorSearchDto dto);
}

View File

@@ -0,0 +1,76 @@
package com.shuwen.data.entity.manage.controller.graph;
import com.shuwen.data.entity.manage.api.model.base.RespDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.EdgeVectorCrudDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VectorSearchDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VertexVectorCrudDto;
import com.shuwen.data.entity.manage.api.service.IGraphVectorService;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;
import javax.annotation.Resource;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 17:15
*/
@Controller
@RequestMapping("/graph/vector")
public class GraphVectorController {
@Resource
private IGraphVectorService graphVectorService;
@PostMapping("/vertex/create")
@ResponseBody
public Object createVertexVector(@RequestBody VertexVectorCrudDto reqDto) {
graphVectorService.create(reqDto);
return RespDto.succeed(reqDto.getRequestId());
}
@PostMapping("/vertex/delete")
@ResponseBody
public Object deleteVertexVector(@RequestBody VertexVectorCrudDto reqDto) {
graphVectorService.delete(reqDto);
return RespDto.succeed(reqDto.getRequestId());
}
@PostMapping("/vertex/rebuild")
@ResponseBody
public Object rebuildVertexVector(@RequestBody VertexVectorCrudDto reqDto) {
graphVectorService.rebuild(reqDto);
return RespDto.succeed(reqDto.getRequestId());
}
@PostMapping("/edge/create")
@ResponseBody
public Object createEdgeVector(@RequestBody EdgeVectorCrudDto reqDto) {
graphVectorService.create(reqDto);
return RespDto.succeed(reqDto.getRequestId());
}
@PostMapping("/edge/delete")
@ResponseBody
public Object deleteEdgeVector(@RequestBody EdgeVectorCrudDto reqDto) {
graphVectorService.delete(reqDto);
return RespDto.succeed(reqDto.getRequestId());
}
@PostMapping("/edge/rebuild")
@ResponseBody
public Object rebuildEdgeVector(@RequestBody EdgeVectorCrudDto reqDto) {
graphVectorService.rebuild(reqDto);
return RespDto.succeed(reqDto.getRequestId());
}
@PostMapping("/search")
@ResponseBody
public Object searchVector(@RequestBody VectorSearchDto reqDto) {
return graphVectorService.search(reqDto);
}
}

View File

@@ -50,6 +50,8 @@ llm.ernie.model=ernie-bot
llm.azure-gpt.url=http://172.22.5.182/v1/chat/completions
llm.azure-gpt.model=gpt-3.5-turbo-16k
llm.wenlv.url=http://183.134.214.162:18000/test/wenlv/chatbot_subtask
llm.deepseek.url=https://ark.cn-beijing.volces.com/api/v3/chat/completions
llm.deepseek.model=deepseek-v3-250324
llm.plain-timeout=12000
llm.stream-timeout=6000
@@ -65,11 +67,18 @@ thread.graph-get.coreSize=100
thread.graph-get.maxSize=200
thread.graph-get.capacity=10000
thread.graph-get.timeout=120
thread.full-text.coreSize=50
thread.full-text.maxSize=100
thread.full-text.capacity=10000
thread.full-text.timeout=120
thread.multi-route-recall.coreSize=100
thread.multi-route-recall.maxSize=200
thread.multi-route-recall.capacity=10000
thread.multi-route-recall.timeout=120
thread.label-full-text.coreSize=100
thread.label-full-text.maxSize=200
thread.label-full-text.capacity=10000
thread.label-full-text.timeout=120
ons.access-channel=CLOUD
ons.namesrvAddr=http://onsaddr.mq-internet-access.mq-internet.aliyuncs.com:80
ons.topic=TEST_GROOT_GRAPH_DATA
ons.vector.group=GID_TEST_GROOT_FACADE_VECTOR_TCP
ons.vector.tag=VERTEX||EDGE
ons.vector.thread-num=1

View File

@@ -62,6 +62,17 @@
</dependency>
<!-- aliyun end -->
<!-- service end -->
<dependency>
<groupId>org.apache.rocketmq</groupId>
<artifactId>rocketmq-client</artifactId>
</dependency>
<dependency>
<groupId>org.apache.rocketmq</groupId>
<artifactId>rocketmq-acl</artifactId>
</dependency>
<!-- service end -->
<!-- spring start -->
<dependency>
<groupId>org.springframework.boot</groupId>

View File

@@ -3,6 +3,7 @@ package com.shuwen.data.entity.manage.service.chat.extract;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import com.shuwen.data.entity.manage.api.model.chat.EdgePathElement;
import com.shuwen.data.entity.manage.api.model.chat.RelationPath;
import com.shuwen.data.entity.manage.api.model.chat.RetrievalInfo;
@@ -111,7 +112,8 @@ public class RelationInfoExtract {
relationPath.setCorpus(corpus);
relationPath.setFillInfo(fillObjectInfo);
relationPath.setFillEntityId(corpusObjectId);
relationPath.setRelationList(Arrays.asList(getEdgePathElement(relName, source, target)));
relationPath.setRelationList(Lists.newArrayList(getEdgePathElement(relName, source, target)));
relationPath.addEdgeId(edge.getEdgeId());
retrievalInfo.addRelation(relationPath);
edgeIdSet.add(key);
@@ -129,6 +131,7 @@ public class RelationInfoExtract {
List<String> pathSentences = new ArrayList<>();
List<EdgePathElement> pathElements = new ArrayList<>();
String corpusObjectId = null;
Set<String> edgeIds = Sets.newSet();
for (int i = 0; i < path.getPath().size(); i++) {
PathElement element = path.getPath().get(i);
if (i == 0) {
@@ -158,6 +161,7 @@ public class RelationInfoExtract {
pathSentences.add(relInfo);
}
pathElements.add(getEdgePathElement(edge.getRelName(), source, target));
edgeIds.add(edge.getEdgeId());
}
retrievalInfo.addEntity(endDetail);
@@ -166,6 +170,7 @@ public class RelationInfoExtract {
relationPath.setFillInfo(true);
relationPath.setFillEntityId(corpusObjectId);
relationPath.setRelationList(pathElements);
relationPath.setEdgeIds(edgeIds);
retrievalInfo.addRelation(relationPath);
}
}

View File

@@ -19,6 +19,8 @@ public class ChatOptionUtils {
private static final String KEY_FULL_TEXT_RECALL = "fullTextRecall";
private static final String KEY_VECTOR_RECALL = "vectorRecall";
public static Boolean alwaysFillSubject(JSONObject option) {
return option != null && BooleanUtils.isTrue(option.getBoolean(KEY_FILL_SUBJECT));
}
@@ -26,4 +28,8 @@ public class ChatOptionUtils {
public static boolean fullTextRecall(JSONObject option) {
return option != null && BooleanUtils.isTrue(option.getBoolean(KEY_FULL_TEXT_RECALL));
}
public static boolean vectorRecall(JSONObject option) {
return option != null && BooleanUtils.isTrue(option.getBoolean(KEY_VECTOR_RECALL));
}
}

View File

@@ -0,0 +1,39 @@
package com.shuwen.data.entity.manage.service.config;
import lombok.Getter;
import lombok.Setter;
import org.apache.rocketmq.client.AccessChannel;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/19 11:44
*/
@Getter
@Setter
@Component
@ConfigurationProperties("ons")
public class OnsConfig {
private AccessChannel accessChannel;
private String namesrvAddr;
private String topic;
private ConsumerConfig vector;
@Getter
@Setter
public static class ConsumerConfig {
private String group;
private String tag;
private int threadNum;
}
}

View File

@@ -29,9 +29,9 @@ public class ThreadConfig {
*/
private ThreadInfo graphGet;
/**
* 图谱全文召回线程池信息
* 多路召回线程池信息
*/
private ThreadInfo fullText;
private ThreadInfo multiRouteRecall;
/**
* 不同label模式全文召回线程池信息
*/

View File

@@ -62,15 +62,15 @@ public class ThreadPoolConfig {
return executor;
}
@Bean(name = "fullTextExecutor")
public Executor fullTextExecutor() {
@Bean(name = "multiRouteRecallExecutor")
public Executor multiRouteRecallExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
ThreadConfig.ThreadInfo threadInfo = threadConfig.getFullText();
ThreadConfig.ThreadInfo threadInfo = threadConfig.getMultiRouteRecall();
executor.setCorePoolSize(threadInfo.getCoreSize());
executor.setMaxPoolSize(threadInfo.getMaxSize());
executor.setQueueCapacity(threadInfo.getCapacity());
executor.setKeepAliveSeconds(threadInfo.getTimeout());
executor.setThreadNamePrefix("full-text-");
executor.setThreadNamePrefix("multi-route-recall-");
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
executor.setWaitForTasksToCompleteOnShutdown(true);
return executor;

View File

@@ -0,0 +1,86 @@
package com.shuwen.data.entity.manage.service.external.mq;
import com.shuwen.data.entity.manage.service.config.OnsConfig;
import com.shuwen.data.entity.manage.service.config.OnsConfig.ConsumerConfig;
import com.shuwen.data.entity.manage.service.external.mq.listener.GenerateVectorListener;
import com.shuwen.ops.shaman.configmap.util.ConfigMapUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.rocketmq.acl.common.AclClientRPCHook;
import org.apache.rocketmq.acl.common.SessionCredentials;
import org.apache.rocketmq.client.AccessChannel;
import org.apache.rocketmq.client.consumer.DefaultMQPushConsumer;
import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently;
import org.apache.rocketmq.client.exception.MQClientException;
import org.apache.rocketmq.remoting.RPCHook;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
/**
* Project: groot-data-bank
* Description:
* Author: Kenn
* Create: 2024/7/25 14:47
*/
@Slf4j
@Component
public class ConsumerHandler {
@Resource
private OnsConfig onsConfig;
@Resource
private GenerateVectorListener generateVectorListener;
private DefaultMQPushConsumer vectorConsumer;
@PostConstruct
public void init() {
if (onsConfig.getAccessChannel() == null || onsConfig.getAccessChannel() == AccessChannel.CLOUD) {
ConfigMapUtils.processAkSkChange(s -> {
log.info("aksk is changed, need to reconstruct ons consumer");
if (vectorConsumer != null) {
vectorConsumer.shutdown();
}
try {
vectorConsumer = buildConsumer(onsConfig.getVector(), generateVectorListener);
vectorConsumer.start();
} catch (MQClientException e) {
throw new RuntimeException(e);
}
});
} else {
try {
vectorConsumer = buildConsumer(onsConfig.getVector(), generateVectorListener);
vectorConsumer.start();
} catch (MQClientException e) {
throw new RuntimeException(e);
}
}
}
public DefaultMQPushConsumer buildConsumer(ConsumerConfig config, MessageListenerConcurrently listener) throws MQClientException {
DefaultMQPushConsumer consumer;
if (onsConfig.getAccessChannel() == null || onsConfig.getAccessChannel() == AccessChannel.CLOUD) {
RPCHook hook = new AclClientRPCHook(new SessionCredentials(ConfigMapUtils.getAk(), ConfigMapUtils.getSk()));
consumer = new DefaultMQPushConsumer(hook);
consumer.setAccessChannel(AccessChannel.CLOUD);
consumer.setNamesrvAddr(onsConfig.getNamesrvAddr());
consumer.setConsumerGroup(config.getGroup());
consumer.subscribe(onsConfig.getTopic(), config.getTag());
consumer.registerMessageListener(listener);
consumer.setConsumeThreadMax(config.getThreadNum());
consumer.setConsumeThreadMin(config.getThreadNum());
} else {
consumer = new DefaultMQPushConsumer();
consumer.setNamesrvAddr(onsConfig.getNamesrvAddr());
consumer.setConsumerGroup(config.getGroup());
consumer.subscribe(onsConfig.getTopic(), config.getTag());
consumer.registerMessageListener(listener);
consumer.setConsumeThreadMax(config.getThreadNum());
consumer.setConsumeThreadMin(config.getThreadNum());
}
return consumer;
}
}

View File

@@ -0,0 +1,106 @@
package com.shuwen.data.entity.manage.service.external.mq;
import com.alibaba.fastjson.JSONObject;
import com.shuwen.data.entity.manage.api.model.exception.QueueServiceException;
import com.shuwen.data.entity.manage.service.config.OnsConfig;
import com.shuwen.data.entity.manage.service.external.ontology.OntologyCacheManager;
import com.shuwen.ops.shaman.configmap.util.ConfigMapUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.acl.common.AclClientRPCHook;
import org.apache.rocketmq.acl.common.SessionCredentials;
import org.apache.rocketmq.client.AccessChannel;
import org.apache.rocketmq.client.exception.MQClientException;
import org.apache.rocketmq.client.producer.DefaultMQProducer;
import org.apache.rocketmq.common.message.Message;
import org.apache.rocketmq.remoting.RPCHook;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.nio.charset.StandardCharsets;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/19 11:49
*/
@Slf4j
@Component
public class ProducerHandler {
private DefaultMQProducer producer;
@Resource
private OnsConfig onsConfig;
@PostConstruct
public void init() {
if (onsConfig.getAccessChannel() == null || onsConfig.getAccessChannel() == AccessChannel.CLOUD) {
ConfigMapUtils.processAkSkChange(s -> {
log.info("aksk is changed, need to reconstruct ons producer");
if (producer != null) {
producer.shutdown();
}
RPCHook hook = new AclClientRPCHook(new SessionCredentials(ConfigMapUtils.getAk(), ConfigMapUtils.getSk()));
producer = new DefaultMQProducer(hook);
producer.setAccessChannel(AccessChannel.CLOUD);
producer.setNamesrvAddr(onsConfig.getNamesrvAddr());
producer.setProducerGroup(onsConfig.getVector().getGroup());
try {
producer.start();
} catch (MQClientException e) {
throw new RuntimeException(e);
}
});
} else {
producer = new DefaultMQProducer();
producer.setNamesrvAddr(onsConfig.getNamesrvAddr());
producer.setProducerGroup(onsConfig.getVector().getGroup());
try {
producer.start();
} catch (MQClientException e) {
throw new RuntimeException(e);
}
}
}
public void sendVertex(String entityId, String label, String graph) {
if (!OntologyCacheManager.getInstance().getGraphConf(graph).isVectorization()) {
return;
}
JSONObject data = new JSONObject();
data.put("type", "VERTEX");
data.put("graph", graph);
data.put("entityId", entityId);
data.put("label", label);
send("VERTEX", entityId, data.toJSONString());
}
public void sendEdge(String edgeId, String graph) {
if (!OntologyCacheManager.getInstance().getGraphConf(graph).isVectorization()) {
return;
}
JSONObject data = new JSONObject();
data.put("type", "VERTEX");
data.put("graph", graph);
data.put("edgeId", edgeId);
send("EDGE", edgeId, data.toJSONString());
}
private void send(String tag, String key, String data) {
if (StringUtils.isBlank(data)) {
return;
}
try {
Message msg = new Message(onsConfig.getTopic(), tag, data.getBytes(StandardCharsets.UTF_8));
if (StringUtils.isNotEmpty(key)) {
msg.setKeys(key);
}
producer.send(msg);
} catch (Exception e) {
throw new QueueServiceException(QueueServiceException.QueueTypeEnum.MQ, "发送MQ消息失败", e);
}
}
}

View File

@@ -0,0 +1,83 @@
package com.shuwen.data.entity.manage.service.external.mq.listener;
import com.alibaba.fastjson.JSONObject;
import com.shuwen.data.entity.manage.api.model.graph.dto.EdgeVectorCrudDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VertexVectorCrudDto;
import com.shuwen.data.entity.manage.api.service.IGraphVectorService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus;
import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently;
import org.apache.rocketmq.common.message.MessageExt;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.nio.charset.StandardCharsets;
import java.util.List;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/19 11:56
*/
@Slf4j
@Component
public class GenerateVectorListener implements MessageListenerConcurrently {
@Resource
private IGraphVectorService graphVectorService;
@Override
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> list, ConsumeConcurrentlyContext consumeConcurrentlyContext) {
for (MessageExt messageExt : list) {
if (StringUtils.equalsAnyIgnoreCase(messageExt.getTags(), "VERTEX", "EDGE")) {
try {
consume(messageExt);
} catch (Exception e) {
log.error("consumer message error, msg key: {}", messageExt.getKeys(), e);
}
} else {
log.error("unknown tag: {}, msg id : {}", messageExt.getTags(), messageExt.getMsgId());
}
}
return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
}
private void consume(MessageExt ext) {
String body = new String(ext.getBody(), StandardCharsets.UTF_8);
JSONObject msg = JSONObject.parseObject(body);
String type = msg.getString("type");
String graph = msg.getString("graph");
if (type.contains("VERTEX")) {
String entityId = msg.getString("entityId");
String label = msg.getString("label");
VertexVectorCrudDto dto = wrap(entityId, label, graph);
graphVectorService.rebuild(dto);
} else if (type.contains("EDGE")) {
String edgeId = msg.getString("edgeId");
EdgeVectorCrudDto dto = wrap(edgeId, graph);
graphVectorService.rebuild(dto);
}
}
private VertexVectorCrudDto wrap(String entityId, String label, String graph) {
VertexVectorCrudDto dto = new VertexVectorCrudDto();
dto.setAppId("internal");
dto.setUid("groot");
dto.setEntityId(entityId);
dto.setLabel(label);
dto.setGraph(graph);
return dto;
}
private EdgeVectorCrudDto wrap(String edgeId, String graph) {
EdgeVectorCrudDto dto = new EdgeVectorCrudDto();
dto.setAppId("internal");
dto.setUid("groot");
dto.setEdgeId(edgeId);
dto.setGraph(graph);
return dto;
}
}

View File

@@ -7,14 +7,9 @@ import com.google.common.collect.Lists;
import com.shuwen.data.entity.manage.api.model.chat.RankData;
import com.shuwen.data.entity.manage.api.model.enums.ReturnCodeEnum;
import com.shuwen.data.entity.manage.api.model.exception.GraphServiceException;
import com.shuwen.data.entity.manage.service.config.OkHttpConfig;
import com.shuwen.data.entity.manage.service.http.HttpHandler;
import com.shuwen.data.entity.manage.service.log.ThirdLog;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@@ -35,7 +30,7 @@ import java.util.stream.Collectors;
public class RankService {
@Resource
private OkHttpClient okHttpClient;
private HttpHandler httpHandler;
@Value("${url.rank}")
private String url;
@@ -45,38 +40,19 @@ public class RankService {
.fluentPut("query", question)
.fluentPut("tags", queryTags)
.fluentPut("record", rankDataList);
RequestBody requestBody = RequestBody.create(OkHttpConfig.JSON, params.toJSONString());
Request request = new Request.Builder()
.url(url)
.post(requestBody)
.build();
Response response = null;
JSONObject linkResp;
String resp;
JSONObject resp;
try {
response = okHttpClient.newCall(request).execute();
if (!response.isSuccessful()) {
throw new GraphServiceException(ReturnCodeEnum.ALGO_ERROR, "rank response is not successful: " + response.code());
}
if (response.body() == null) {
throw new GraphServiceException(ReturnCodeEnum.ALGO_ERROR, "rank response is null");
}
resp = response.body().string();
if (StringUtils.isBlank(resp)) {
String response = httpHandler.doPost(url, new JSONObject(), new JSONObject(), params.toJSONString());
if (StringUtils.isEmpty(response)) {
throw new GraphServiceException(ReturnCodeEnum.ALGO_ERROR, "rank response is empty");
}
linkResp = JSON.parseObject(resp);
ThirdLog.info("rank", params.toJSONString(), resp, requestId);
resp = JSON.parseObject(response);
ThirdLog.info("rank", params.toJSONString(), response, requestId);
} catch (Exception e) {
ThirdLog.error("rank", params.toJSONString(), e, requestId);
throw new GraphServiceException(ReturnCodeEnum.ALGO_ERROR, e.getMessage());
} finally {
IOUtils.closeQuietly(response);
throw e;
}
if (!linkResp.containsKey("data")) {
return Lists.newArrayList();
}
JSONArray data = linkResp.getJSONArray("data");
JSONArray data = resp.getJSONArray("data");
if (CollectionUtils.isEmpty(data)) {
return Lists.newArrayList();
}

View File

@@ -2,28 +2,20 @@ package com.shuwen.data.entity.manage.service.external.similarity;
import cn.hutool.core.lang.Pair;
import cn.hutool.core.util.IdUtil;
import cn.hutool.http.HttpResponse;
import cn.hutool.http.HttpUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.shuwen.data.entity.manage.api.model.enums.ReturnCodeEnum;
import com.shuwen.data.entity.manage.api.model.exception.AlgoServiceException;
import com.shuwen.data.entity.manage.api.model.exception.GraphServiceException;
import com.shuwen.data.entity.manage.service.config.OkHttpConfig;
import com.shuwen.data.entity.manage.service.http.HttpHandler;
import com.shuwen.data.entity.manage.service.util.SegmentUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.math.BigDecimal;
import java.util.List;
/**
@@ -37,7 +29,7 @@ import java.util.List;
public class EmbeddingService {
@Resource
private OkHttpClient okHttpClient;
private HttpHandler httpHandler;
@Value("${url.embedding}")
private String url;
@@ -51,78 +43,59 @@ public class EmbeddingService {
* @param input
* @return
*/
public List<Pair<String, List<Float>>> embeddings(String input) throws AlgoServiceException {
public List<Pair<String, List<Float>>> embeddings(String input) {
List<String> chunks = SegmentUtils.chunk(SegmentUtils.sentences(input));
JSONObject body = new JSONObject();
body.put("request_id", IdUtil.simpleUUID());
body.put("biz_name", "wenlv");
body.put("content_text", chunks);
body.put("type", embeddingModel);
String resStr = null;
try {
List<Pair<String, List<Float>>> result = Lists.newArrayList();
HttpResponse response = HttpUtil.createPost(url + "/embedding/generate")
.body(body.toJSONString()).execute();
resStr = response.body();
if (response.getStatus() != 200 || StringUtils.isEmpty(resStr)) {
throw new AlgoServiceException("embedding error " + resStr);
}
JSONObject resObj = JSON.parseObject(response.body());
JSONArray data = resObj.getJSONArray("data");
for (int i = 0; i < chunks.size(); i++) {
JSONArray obj = data.getJSONArray(i);
result.add(new Pair<>(chunks.get(i), obj.toJavaList(Float.class)));
}
return result;
} catch (Exception e) {
log.error(String.format("embeddings error body is %s res is %s", body.toJSONString(), resStr), e);
if (e instanceof AlgoServiceException) {
throw e;
}
return null;
List<Pair<String, List<Float>>> result = Lists.newArrayList();
String response = httpHandler.doPost(url + "/embedding/generate", new JSONObject(), new JSONObject(), body.toJSONString());
if (StringUtils.isEmpty(response)) {
throw new AlgoServiceException("embedding error, response: " + response);
}
JSONObject ret;
try {
ret = JSON.parseObject(response);
} catch (Exception e) {
log.error("can not parser response: {}, response: {}", JSON.toJSONString(chunks), response, e);
throw e;
}
if (!ret.getBooleanValue("success")) {
throw new AlgoServiceException("embedding error, response: " + response);
}
JSONArray data = ret.getJSONArray("data");
for (int i = 0; i < chunks.size(); i++) {
JSONArray obj = data.getJSONArray(i);
result.add(new Pair<>(chunks.get(i), obj.toJavaList(Float.class)));
}
return result;
}
public List<Float> similarity(String question, List<String> chunkList) {
JSONObject params = new JSONObject()
.fluentPut("bizName", "groot")
.fluentPut("query", question)
.fluentPut("sentences", chunkList)
.fluentPut("type", "flag");
RequestBody requestBody = RequestBody.create(OkHttpConfig.JSON, params.toJSONString());
Request request = new Request.Builder()
.url(url + "/embedding/similarity")
.post(requestBody)
.build();
Response response = null;
JSONObject linkResp;
String resp;
public List<List<BigDecimal>> embeddings(List<String> textList) {
JSONObject reqBody = new JSONObject()
.fluentPut("content_text", textList);
reqBody.put("type", embeddingModel);
String response = httpHandler.doPost(url + "/embedding/generate", new JSONObject(), new JSONObject(), reqBody.toJSONString());
if (StringUtils.isEmpty(response)) {
throw new AlgoServiceException("embedding error, response: " + response);
}
JSONObject ret;
try {
response = okHttpClient.newCall(request).execute();
if (response.body() == null || response.code() == 500) {
throw new GraphServiceException(ReturnCodeEnum.ALGO_ERROR, "entity link response is null");
}
resp = response.body().string();
if (StringUtils.isBlank(resp)) {
throw new GraphServiceException(ReturnCodeEnum.ALGO_ERROR, "entity link response is empty");
}
linkResp = JSON.parseObject(resp);
ret = JSON.parseObject(response);
} catch (Exception e) {
throw new GraphServiceException(ReturnCodeEnum.ALGO_ERROR, e.getMessage());
} finally {
IOUtils.closeQuietly(response);
log.error("can not parser response: {}, response: {}", JSON.toJSONString(textList), response, e);
throw e;
}
if (!linkResp.containsKey("data")) {
return Lists.newArrayList();
if (!ret.getBooleanValue("success")) {
throw new AlgoServiceException("embedding error, response: " + response);
}
JSONObject data = linkResp.getJSONObject("data");
if (!data.containsKey("flag")) {
return Lists.newArrayList();
JSONArray data = ret.getJSONArray("data");
List<List<BigDecimal>> results = Lists.newArrayList();
for (int i = 0; i < data.size(); i++) {
results.add(data.getJSONArray(i).toJavaList(BigDecimal.class));
}
JSONArray flag = data.getJSONArray("flag");
if (null == flag) {
return Lists.newArrayList();
}
return flag.toJavaList(Float.class);
return results;
}
}

View File

@@ -0,0 +1,103 @@
package com.shuwen.data.entity.manage.service.http;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.TypeReference;
import com.shuwen.data.entity.manage.api.model.enums.ReturnCodeEnum;
import com.shuwen.data.entity.manage.api.model.exception.GraphServiceException;
import okhttp3.HttpUrl;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.Objects;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2024/7/15 17:46
*/
public abstract class BaseHttpHandler {
abstract OkHttpClient getClient();
public <T> T doPost(String url, String body, TypeReference<T> typeReference) {
return doPost(url, null, null, body, typeReference);
}
public <T> T doPost(String url, JSONObject queries, JSONObject headers, String body, TypeReference<T> typeReference) {
String response = doPost(url, queries, headers, body);
if (StringUtils.isEmpty(response)) {
return null;
}
return JSONObject.parseObject(response, typeReference);
}
public String doPost(String url, JSONObject queries, JSONObject headers, String body) {
String exactUrl = url(url, queries);
if (StringUtils.isEmpty(exactUrl)) {
return null;
}
RequestBody requestBody = RequestBody.create(body, MediaType.parse("application/json; charset=utf-8"));
Request request = builder(exactUrl, headers)
.post(requestBody)
.build();
return execute(request);
}
public String doGet(String url, JSONObject queries, JSONObject headers) {
String exactUrl = url(url, queries);
if (StringUtils.isEmpty(exactUrl)) {
return null;
}
Request request = builder(exactUrl, headers)
.build();
return execute(request);
}
private String url(String url, JSONObject queries) {
HttpUrl httpUrl = HttpUrl.parse(url);
if (httpUrl == null) {
return null;
}
HttpUrl.Builder urlBuilder = httpUrl.newBuilder();
if (MapUtils.isNotEmpty(queries)) {
queries.forEach((key, value) -> urlBuilder.addQueryParameter(key, (String) value));
}
return urlBuilder.toString();
}
private Request.Builder builder(String url, JSONObject headers) {
Request.Builder builder = new Request.Builder().url(url);
if (MapUtils.isNotEmpty(headers)) {
headers.forEach((key, value) -> builder.addHeader(key, (String) value));
}
return builder;
}
private String execute(Request request) {
Response response = null;
try {
response = getClient().newCall(request).execute();
ResponseBody body = response.body();
if (body == null) {
return null;
}
if (!response.isSuccessful()) {
throw new GraphServiceException(ReturnCodeEnum.HTTP_ERROR, "http error " + response.code());
}
return body.string();
} catch (Exception e) {
throw new GraphServiceException(ReturnCodeEnum.HTTP_ERROR, "http error", e);
} finally {
if (Objects.nonNull(response)) {
response.close();
}
}
}
}

View File

@@ -0,0 +1,24 @@
package com.shuwen.data.entity.manage.service.http;
import okhttp3.OkHttpClient;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
/**
* Project: groot-material
* Description:
* Author: Kenn
* Create: 2022/12/2 16:50
*/
@Component
public class HttpHandler extends BaseHttpHandler {
@Resource
private OkHttpClient okHttpClient;
@Override
protected OkHttpClient getClient() {
return okHttpClient;
}
}

View File

@@ -14,6 +14,7 @@ import com.shuwen.data.entity.manage.api.service.IEdgeManageService;
import com.shuwen.data.entity.manage.common.entity.router.EntityFieldRouter;
import com.shuwen.data.entity.manage.common.utils.AssertParamUtils;
import com.shuwen.data.entity.manage.common.utils.BeanUtils;
import com.shuwen.data.entity.manage.service.external.mq.ProducerHandler;
import com.shuwen.data.entity.manage.service.integration.configmap.SwitchConfig;
import com.shuwen.data.entity.manage.service.internal.AuthorityUtils;
import com.shuwen.data.entity.manage.service.internal.ICheckService;
@@ -85,6 +86,9 @@ public class EdgeManageService implements IEdgeManageService {
@Resource
private PropMergeService propMergeService;
@Resource
private ProducerHandler producerHandler;
@Value("${deploy.mode}")
private String deployMode;
@@ -116,6 +120,10 @@ public class EdgeManageService implements IEdgeManageService {
writeEdge(graph, internalBaseEdge, ModifyActionEnum.ADD);
EdgeModifyResult crudResult = new EdgeModifyResult("edgeAdd", GraphRegisterUtils.genEdgeKey(graph, internalBaseEdge, DIRECTION_OUT), null);
crudResult.setModify(true);
// 发送MQ
producerHandler.sendEdge(GraphRegisterUtils.genEdgeKey(graph, internalBaseEdge, DIRECTION_OUT), graph);
return crudResult;
}
@@ -175,6 +183,10 @@ public class EdgeManageService implements IEdgeManageService {
internalBaseEdge.setProperties(modifyData);
writeEdge(graph, internalBaseEdge, ModifyActionEnum.UPDATE);
}
// 发送MQ
producerHandler.sendEdge(edgeId, graph);
return new EdgeModifyResult("edgeUpdate", edgeId, modifyField);
}
@@ -263,6 +275,10 @@ public class EdgeManageService implements IEdgeManageService {
} finally {
graphWriteLock.edgeRelease(edgeId);
}
// 发送MQ
producerHandler.sendEdge(edgeId, graph);
EdgeModifyResult result = new EdgeModifyResult("edgeRemove", edgeId, null);
result.setModify(true);
return result;

View File

@@ -3,6 +3,7 @@ package com.shuwen.data.entity.manage.service.impl;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Sets;
import com.shuwen.data.entity.manage.api.model.base.AbstractRequest;
import com.shuwen.data.entity.manage.api.model.chat.GraphVectorInfo;
import com.shuwen.data.entity.manage.api.model.graph.base.BaseGraphEdge;
import com.shuwen.data.entity.manage.api.model.graph.base.GraphPathResult;
import com.shuwen.data.entity.manage.api.model.graph.base.SnowflakeQueryItem;
@@ -11,11 +12,13 @@ import com.shuwen.data.entity.manage.api.model.graph.base.option.FetchOption;
import com.shuwen.data.entity.manage.api.model.graph.base.option.RelOption;
import com.shuwen.data.entity.manage.api.model.graph.dto.GraphEdgeDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.GraphSnowflakeDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VectorSearchDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VertexDetailDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VertexSearchDto;
import com.shuwen.data.entity.manage.api.model.graph.result.GraphEdgeResult;
import com.shuwen.data.entity.manage.api.model.search.EntitySearchResult;
import com.shuwen.data.entity.manage.api.service.IGraphDiscoverService;
import com.shuwen.data.entity.manage.api.service.IGraphVectorService;
import com.shuwen.data.entity.manage.api.service.IVertexFetchService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
@@ -24,7 +27,6 @@ import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -51,6 +53,9 @@ public class GraphChatRetrievalService {
@Resource
private IGraphDiscoverService graphDiscoverService;
@Resource
private IGraphVectorService graphVectorService;
private static final Set<String> FETCH_VERTEX_PROPS = Sets.newHashSet("name", "summary", "image", "content", "tags", "struct_address");
private static final Set<String> FETCH_EXCLUDE_PROPS = Sets.newHashSet("resume", "trail", "trail_info", "project", "short_summary", "summary_list", "data_source", "importance_info");
@@ -155,7 +160,6 @@ public class GraphChatRetrievalService {
return result.getPaths();
}
/**
* 地理-扩散查询
* 省-查三跳;市-查二跳
@@ -201,6 +205,14 @@ public class GraphChatRetrievalService {
return result.getPaths();
}
public List<GraphVectorInfo> vector(String query, String graph, int size) {
VectorSearchDto reqDto = new VectorSearchDto();
baseWrap(reqDto, graph);
reqDto.setQuery(query);
reqDto.setSize(size);
return graphVectorService.search(reqDto);
}
private void baseWrap(AbstractRequest req, String graph) {
req.setAppId("internal-chat");
req.setUid("groot");

View File

@@ -16,6 +16,7 @@ import com.shuwen.data.entity.manage.api.model.chat.FullTextRetrievalInfo;
import com.shuwen.data.entity.manage.api.model.chat.FuncNameConstants;
import com.shuwen.data.entity.manage.api.model.chat.FuncOriginal;
import com.shuwen.data.entity.manage.api.model.chat.FuncResult;
import com.shuwen.data.entity.manage.api.model.chat.GraphVectorInfo;
import com.shuwen.data.entity.manage.api.model.chat.IntentInfo;
import com.shuwen.data.entity.manage.api.model.chat.IntentRephrase;
import com.shuwen.data.entity.manage.api.model.chat.RankData;
@@ -24,12 +25,15 @@ import com.shuwen.data.entity.manage.api.model.chat.RetrievalInfo;
import com.shuwen.data.entity.manage.api.model.chat.RetrievalRequestDto;
import com.shuwen.data.entity.manage.api.model.chat.RetrievalResult;
import com.shuwen.data.entity.manage.api.model.chat.Slots;
import com.shuwen.data.entity.manage.api.model.chat.VectorRetrievalInfo;
import com.shuwen.data.entity.manage.api.model.enums.ReturnCodeEnum;
import com.shuwen.data.entity.manage.api.model.exception.GraphServiceException;
import com.shuwen.data.entity.manage.api.model.graph.base.BaseGraphEdge;
import com.shuwen.data.entity.manage.api.model.graph.base.VertexNeighbor;
import com.shuwen.data.entity.manage.api.model.graph.base.VertexPath;
import com.shuwen.data.entity.manage.api.model.graph.dto.GraphNeighborDto;
import com.shuwen.data.entity.manage.api.model.graph.result.EdgeVectorResult;
import com.shuwen.data.entity.manage.api.model.graph.result.VertexVectorResult;
import com.shuwen.data.entity.manage.api.service.IGraphChatService;
import com.shuwen.data.entity.manage.service.algo.AlgoService;
import com.shuwen.data.entity.manage.service.chat.ChatHandler;
@@ -129,7 +133,7 @@ public class GraphChatService implements IGraphChatService {
private Executor graphGetExecutor;
@Resource
private Executor fullTextExecutor;
private Executor multiRouteRecallExecutor;
@Resource
private Executor labelFullTextExecutor;
@@ -246,7 +250,17 @@ public class GraphChatService implements IGraphChatService {
log.error("full text retrieval error", e);
return Pair.of(g, new FullTextRetrievalInfo());
}
}, fullTextExecutor));
}, multiRouteRecallExecutor));
}
}
}
// 向量召回
List<CompletableFuture<Pair<String, VectorRetrievalInfo>>> vectorFutures = Lists.newArrayList();
if (ChatOptionUtils.vectorRecall(option)) {
for (String g : graphList) {
if (!g.equals("tsldzg")) {
vectorFutures.add(CompletableFuture.supplyAsync(() -> Pair.of(g, vector(question, g, requestId)), multiRouteRecallExecutor));
}
}
}
@@ -270,6 +284,18 @@ public class GraphChatService implements IGraphChatService {
retrievalInfo.setFullTextRecallEntityMap(fullTextRetrievalInfo.getFullTextRecallEntityMap());
retrievalInfo.putUsage(fullTextRetrievalInfo.getUsage());
}
// 向量结果获取
Map<String, VectorRetrievalInfo> vectorRetrievalInfoMap = vectorFutures.stream()
.map(CompletableFuture::join)
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
VectorRetrievalInfo vectorRetrievalInfo = vectorRetrievalInfoMap.get(graph);
if (vectorRetrievalInfo != null) {
retrievalInfo.setVectorRecallEntityMap(vectorRetrievalInfo.getVectorRecallEntityMap());
if (MapUtils.isNotEmpty(vectorRetrievalInfo.getVectorRecallRelationPathMap())) {
retrievalInfo.setVectorRecallRelationPathList(Lists.newArrayList(vectorRetrievalInfo.getVectorRecallRelationPathMap().values()));
}
}
} else {
// 多图召回
List<CompletableFuture<Triple<String, SlotsContext, RetrievalInfo>>> futureList = Lists.newArrayList();
@@ -286,7 +312,12 @@ public class GraphChatService implements IGraphChatService {
.map(CompletableFuture::join)
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
// FC召回结果和全文结果拼接
// 向量结果获取
Map<String, VectorRetrievalInfo> vectorRetrievalInfoMap = vectorFutures.stream()
.map(CompletableFuture::join)
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
// FC召回结果、全文结果、向量结果拼接
Map<String, Pair<SlotsContext, RetrievalInfo>> results = Maps.newConcurrentMap();
futureResultList.forEach(triple -> {
FullTextRetrievalInfo fullTextRetrievalInfo = fullTextRetrievalInfoMap.get(triple.getLeft());
@@ -294,6 +325,13 @@ public class GraphChatService implements IGraphChatService {
triple.getRight().setFullTextRecallEntityMap(fullTextRetrievalInfo.getFullTextRecallEntityMap());
triple.getRight().putUsage(fullTextRetrievalInfo.getUsage());
}
VectorRetrievalInfo vectorRetrievalInfo = vectorRetrievalInfoMap.get(triple.getLeft());
if (vectorRetrievalInfo != null && triple.getRight() != null) {
triple.getRight().setVectorRecallEntityMap(vectorRetrievalInfo.getVectorRecallEntityMap());
if (MapUtils.isNotEmpty(vectorRetrievalInfo.getVectorRecallRelationPathMap())) {
triple.getRight().setVectorRecallRelationPathList(Lists.newArrayList(vectorRetrievalInfo.getVectorRecallRelationPathMap().values()));
}
}
results.put(triple.getLeft(), Pair.of(triple.getMiddle(), triple.getRight()));
});
@@ -503,7 +541,6 @@ public class GraphChatService implements IGraphChatService {
}
start = System.currentTimeMillis();
List<JSONObject> filterDetails = Lists.newArrayList();
fullTextRetrievalInfo.setRankDataList(rankDataList);
List<Float> ranks = rankService.rank(question, queryTags, rankDataList, requestId);
if (CollectionUtils.isEmpty(ranks) || ranks.size() != rankDataList.size()) {
@@ -512,6 +549,7 @@ public class GraphChatService implements IGraphChatService {
fullTextRetrievalInfo.setRankScoreList(ranks);
fullTextRetrievalInfo.putUsage("full_text_rank", System.currentTimeMillis() - start);
List<JSONObject> filterDetails = Lists.newArrayList();
for (int i = 0; i < details.size(); i++) {
float rank = ranks.get(i);
if (rank >= GraphRecallParamsConfig.getFullTextThreshold()) {
@@ -534,6 +572,79 @@ public class GraphChatService implements IGraphChatService {
return fullTextRetrievalInfo;
}
private VectorRetrievalInfo vector(String question, String graph, String requestId) {
VectorRetrievalInfo vectorRetrievalInfo = new VectorRetrievalInfo();
long start = System.currentTimeMillis();
List<GraphVectorInfo> vectorInfoList = graphChatRetrievalService.vector(question, graph, 10);
vectorRetrievalInfo.putUsage("vector_recall", System.currentTimeMillis() - start);
// 排序
start = System.currentTimeMillis();
List<RankData> rankDataList = Lists.newArrayList();
for (GraphVectorInfo vectorInfo : vectorInfoList) {
RankData rankData = new RankData("GRAPH", vectorInfo.getContent(), null);
rankDataList.add(rankData);
}
List<Float> ranks = rankService.rank(question, null, rankDataList, requestId);
if (CollectionUtils.isEmpty(ranks) || ranks.size() != rankDataList.size()) {
return vectorRetrievalInfo;
}
vectorRetrievalInfo.setRankScoreList(ranks);
vectorRetrievalInfo.putUsage("full_text_rank", System.currentTimeMillis() - start);
// 过滤
List<Pair<Float, GraphVectorInfo>> filters = Lists.newArrayList();
for (int i = 0; i < vectorInfoList.size(); i++) {
float rank = ranks.get(i);
if (rank >= GraphRecallParamsConfig.getVectorThreshold()) {
GraphVectorInfo vectorInfo = vectorInfoList.get(i);
filters.add(Pair.of(rank, vectorInfo));
}
}
if (CollectionUtils.isEmpty(filters)) {
return vectorRetrievalInfo;
}
filters.sort((o1, o2) -> {
float rank1 = o1.getKey();
float rank2 = o2.getKey();
return Float.compare(rank2, rank1);
});
for (int i = 0; i < Math.min(filters.size(), 5); i++) {
GraphVectorInfo graphVectorInfo = filters.get(i).getRight();
if (graphVectorInfo.getVectorDataType() == GraphVectorInfo.VectorDataType.VERTEX) {
VertexVectorResult result = (VertexVectorResult) graphVectorInfo;
vectorRetrievalInfo.addEntity(result.getEntityId(), result.getEntity());
} else if (graphVectorInfo.getVectorDataType() == GraphVectorInfo.VectorDataType.EDGE) {
EdgeVectorResult result = (EdgeVectorResult) graphVectorInfo;
RelationPath relationPath = new RelationPath();
relationPath.setSourceLabel(result.getLabelOut());
relationPath.setSourceId(result.getEntityIdOut());
relationPath.setTargetLabel(result.getLabelIn());
relationPath.setTargetId(result.getEntityIdIn());
relationPath.setCorpus(result.getContent());
relationPath.setFillInfo(false);
EdgePathElement element = new EdgePathElement();
element.setRelName(result.getRelName());
element.setLabelOut(result.getLabelOut());
element.setEntityIdOut(result.getEntityIdOut());
element.setEntityNameOut(result.getSource().get("name").toString());
element.setLabelIn(result.getLabelIn());
element.setEntityIdIn(result.getEntityIdIn());
element.setEntityNameIn(result.getTarget().get("name").toString());
relationPath.setRelationList(Lists.newArrayList(element));
relationPath.addEdgeId(result.getEdgeId());
vectorRetrievalInfo.addRelationPath(result.getEdgeId(), relationPath);
vectorRetrievalInfo.addEntity(result.getEntityIdOut(), result.getSource());
vectorRetrievalInfo.addEntity(result.getEntityIdIn(), result.getTarget());
}
}
return vectorRetrievalInfo;
}
private RetrievalInfo retrieval(String graph, IntentInfo info, String restrict, Set<String> props, JSONObject option, Boolean tripPlan, Boolean needVisualization, String question) {
Slots originalSlots = info.getSlots();
IntentRephrase rephrase = SlotsRephraseUtils.rephrase(info);
@@ -1153,7 +1264,7 @@ public class GraphChatService implements IGraphChatService {
result.addCorpus(emptyStr);
// 后续处理
postBuildResult(retrievalInfo, props, needVisualization, result, false, Sets.newHashSet());
postBuildResult(retrievalInfo, props, needVisualization, result, false, false, Sets.newHashSet(), Sets.newHashSet());
return;
}
@@ -1199,16 +1310,18 @@ public class GraphChatService implements IGraphChatService {
}
// 后续处理
postBuildResult(retrievalInfo, props, needVisualization, result, true, Sets.newHashSet());
postBuildResult(retrievalInfo, props, needVisualization, result, true, true, Sets.newHashSet(), Sets.newHashSet());
return;
}
// 关系语料
Set<String> hasAddCorpusEntities = Sets.newHashSet();
Set<String> hasAddCorpusRelations = Sets.newHashSet();
for (RelationPath relationPath : retrievalInfo.getRecallRelationPathList()) {
if (StringUtils.isNotEmpty(relationPath.getCorpus())) {
result.addRelationDetail(new JSONObject().fluentPut("corpus", relationPath.getCorpus()).fluentPut("path", relationPath.getRelationList()));
result.addCorpus(relationPath.getCorpus());
hasAddCorpusRelations.addAll(relationPath.getEdgeIds());
// 添加语料对应实体信息
if (StringUtils.isNotEmpty(relationPath.getSourceId()) && StringUtils.isNotEmpty(relationPath.getTargetId())) {
@@ -1372,10 +1485,10 @@ public class GraphChatService implements IGraphChatService {
}
// 后续处理
postBuildResult(retrievalInfo, props, needVisualization, result, true, hasAddCorpusEntities);
postBuildResult(retrievalInfo, props, needVisualization, result, true, true, hasAddCorpusEntities, hasAddCorpusRelations);
}
private void postBuildResult(RetrievalInfo retrievalInfo, Set<String> props, boolean needVisualization, RetrievalResult result, boolean addFulltext, Set<String> hasAddCorpusEntities) {
private void postBuildResult(RetrievalInfo retrievalInfo, Set<String> props, boolean needVisualization, RetrievalResult result, boolean addFulltext, boolean addVector, Set<String> hasAddCorpusEntities, Set<String> hasAddCorpusRelations) {
// 补充全文检索结果
if (addFulltext && MapUtils.isNotEmpty(retrievalInfo.getFullTextRecallEntityMap())) {
List<JSONObject> fullTextItemList = retrievalInfo.getFullTextRecallEntityMap().values().parallelStream()
@@ -1388,6 +1501,39 @@ public class GraphChatService implements IGraphChatService {
result.addEntityDetail(fulltextItem);
result.addCorpus(corpus);
result.addCorpusEntity(corpus, Lists.newArrayList(entityId), "FULL_TEXT");
hasAddCorpusEntities.add(entityId);
}
}
// 补充向量检索结果
if (addVector) {
if (MapUtils.isNotEmpty(retrievalInfo.getVectorRecallEntityMap())) {
List<Map<String, Object>> vectorEntityList = retrievalInfo.getVectorRecallEntityMap().entrySet().parallelStream()
.filter(entry -> !hasAddCorpusEntities.contains(entry.getKey()))
.map(Map.Entry::getValue)
.collect(Collectors.toList());
List<JSONObject> vectorItemList = vectorEntityList.parallelStream()
.map(entity -> entityInfoWrap.getEntityInfo(entity, props, true, retrievalInfo.getGraph()))
.collect(Collectors.toList());
for (JSONObject vectorItem : vectorItemList) {
String entityId = vectorItem.getString("id");
String corpus = vectorItem.getString("corpus");
result.addEntityDetail(vectorItem);
result.addCorpus(corpus);
result.addCorpusEntity(corpus, Lists.newArrayList(entityId), "VECTOR");
hasAddCorpusEntities.add(entityId);
}
}
if (CollectionUtils.isNotEmpty(retrievalInfo.getVectorRecallRelationPathList())) {
List<RelationPath> vectorRecallRelationPathList = retrievalInfo.getVectorRecallRelationPathList();
for (RelationPath relationPath : vectorRecallRelationPathList) {
if (CollectionUtils.containsAll(hasAddCorpusRelations, relationPath.getEdgeIds())) {
continue;
}
result.addCorpusEntity(relationPath.getCorpus(), Lists.newArrayList(relationPath.getSourceId(), relationPath.getTargetId()));
result.addRelationDetail(new JSONObject().fluentPut("corpus", relationPath.getCorpus()).fluentPut("path", relationPath.getRelationList()));
}
}
}

View File

@@ -0,0 +1,746 @@
package com.shuwen.data.entity.manage.service.impl;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.shuwen.data.entity.manage.api.model.chat.EdgeVectorInfo;
import com.shuwen.data.entity.manage.api.model.chat.GraphVectorInfo;
import com.shuwen.data.entity.manage.api.model.chat.VertexVectorInfo;
import com.shuwen.data.entity.manage.api.model.enums.ReturnCodeEnum;
import com.shuwen.data.entity.manage.api.model.exception.GraphServiceException;
import com.shuwen.data.entity.manage.api.model.exception.QueueServiceException;
import com.shuwen.data.entity.manage.api.model.exception.QueueServiceException.QueueTypeEnum;
import com.shuwen.data.entity.manage.api.model.graph.dto.EdgeCrudDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.EdgeVectorCrudDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VectorSearchDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VertexDetailDto;
import com.shuwen.data.entity.manage.api.model.graph.dto.VertexVectorCrudDto;
import com.shuwen.data.entity.manage.api.model.graph.result.EdgeDetailResult;
import com.shuwen.data.entity.manage.api.model.graph.result.EdgeVectorResult;
import com.shuwen.data.entity.manage.api.model.graph.result.VertexVectorResult;
import com.shuwen.data.entity.manage.api.service.IEdgeManageService;
import com.shuwen.data.entity.manage.api.service.IGraphVectorService;
import com.shuwen.data.entity.manage.api.service.IVertexFetchService;
import com.shuwen.data.entity.manage.common.entity.constant.GraphConstants;
import com.shuwen.data.entity.manage.common.utils.AssertParamUtils;
import com.shuwen.data.entity.manage.common.utils.ids.IDUtils;
import com.shuwen.data.entity.manage.service.chat.utils.ContentToStrUtils;
import com.shuwen.data.entity.manage.service.external.ontology.OntologyCacheManager;
import com.shuwen.data.entity.manage.service.external.similarity.EmbeddingService;
import com.shuwen.data.entity.manage.service.integration.configmap.GraphVectorParamsConfig;
import com.shuwen.data.entity.manage.service.lock.GraphLockUtils;
import com.shuwen.data.entity.manage.service.lock.IGraphVectorWriteLock;
import com.shuwen.data.entity.manage.service.ontology.module.VertexDesc;
import com.shuwen.data.entity.manage.service.task.module.l1.RelationDesc;
import com.shuwen.data.entity.manage.service.task.module.l2.PropDesc;
import com.shuwen.data.entity.manage.service.task.module.l2.SubClassDesc;
import com.shuwen.data.entity.manage.service.util.LLMStreamingUtils;
import com.shuwen.data.entity.manage.service.util.NoSolutionUtils;
import com.shuwen.llm.enums.LLMType;
import com.shuwen.llm.message.GPTMessage;
import com.shuwen.llm.message.LLMMessage;
import com.shuwen.llm.service.LLMService;
import com.shuwen.search.proxy.api.entity.base.BoolQuery;
import com.shuwen.search.proxy.api.entity.base.FieldFilter;
import com.shuwen.search.proxy.api.entity.base.KNNVector;
import com.shuwen.search.proxy.api.entity.dto.bulk.BulkReqDto;
import com.shuwen.search.proxy.api.entity.dto.bulk.BulkRespDto;
import com.shuwen.search.proxy.api.entity.dto.bulk.OperateBulk;
import com.shuwen.search.proxy.api.entity.dto.common.AbstractCommonRequest;
import com.shuwen.search.proxy.api.entity.dto.common.DeleteByQueryReqDto;
import com.shuwen.search.proxy.api.entity.dto.common.FilterReqDto;
import com.shuwen.search.proxy.api.entity.dto.common.ItemsRespDto;
import com.shuwen.search.proxy.api.entity.enums.FieldFilterTypeEnum;
import com.shuwen.search.proxy.api.service.IBulkService;
import com.shuwen.search.proxy.api.service.IFilterService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import javax.annotation.Resource;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import static com.shuwen.data.entity.manage.service.chat.utils.ContentToStrUtils.PAIR_REL_NAME;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/19 15:09
*/
@Slf4j
@Service
public class GraphVectorServiceImpl implements IGraphVectorService {
@Resource
private IVertexFetchService vertexFetchService;
@Resource
private IEdgeManageService edgeManageService;
@Resource(name = "vectorFilterService")
private IFilterService filterService;
@Resource(name = "vectorBulkService")
private IBulkService bulkService;
@Resource
private IGraphVectorWriteLock graphVectorWriteLock;
@Resource
private LLMService llmService;
@Resource
private EmbeddingService embeddingService;
private static final List<String> UNUSED_VERTEX_FIELD = Lists.newArrayList("entity_id", "image", "resume", "trail", "trail_info", "biz_info", "project", "hot", "valid", "create_time", "update_time", "entity_id", "label", "_type", "summary_list", "data_source", "importance_info");
private static final List<String> UNUSED_EDGE_FIELD = Lists.newArrayList("edge_id", "rel_name", "label_out", "label_in", "entity_id_out", "entity_id_in", "subclass_out", "subclass_in", "rel_id", "direction_certain", "project", "create_time", "update_time", "desc_list", "expired");
@Override
public void create(VertexVectorCrudDto dto) {
AssertParamUtils.isBlank(dto.getLabel(), ReturnCodeEnum.PARAM_INVALID, "label can not be null");
AssertParamUtils.isBlank(dto.getEntityId(), ReturnCodeEnum.PARAM_INVALID, "entityId can not be null");
Map<String, Object> detail = detail(dto.getEntityId(), dto.getLabel(), dto.getGraph());
AssertParamUtils.isNull(detail, ReturnCodeEnum.EMPTY, "the entity be null");
save(dto.getEntityId(), dto.getLabel(), detail, dto.getGraph());
}
@Override
public void delete(VertexVectorCrudDto dto) {
AssertParamUtils.isBlank(dto.getLabel(), ReturnCodeEnum.PARAM_INVALID, "label can not be null");
AssertParamUtils.isBlank(dto.getEntityId(), ReturnCodeEnum.PARAM_INVALID, "entityId can not be null");
delete(dto.getEntityId(), dto.getLabel(), dto.getGraph());
}
@Override
public void rebuild(VertexVectorCrudDto dto) {
AssertParamUtils.isBlank(dto.getLabel(), ReturnCodeEnum.PARAM_INVALID, "label can not be null");
AssertParamUtils.isBlank(dto.getEntityId(), ReturnCodeEnum.PARAM_INVALID, "entityId can not be null");
Map<String, Object> detail = detail(dto.getEntityId(), dto.getLabel(), dto.getGraph());
if (detail == null) {
delete(dto.getEntityId(), dto.getLabel(), dto.getGraph());
} else {
save(dto.getEntityId(), dto.getLabel(), detail, dto.getGraph());
}
}
@Override
public void create(EdgeVectorCrudDto dto) {
AssertParamUtils.isBlank(dto.getEdgeId(), ReturnCodeEnum.PARAM_INVALID, "edgeId can not be null");
Map<String, Object> detail = detail(dto.getEdgeId(), dto.getGraph());
String relName = detail.get("rel_name").toString();
String labelOut = detail.get("label_out").toString();
String entityIdOut = detail.get("entity_id_out").toString();
String labelIn = detail.get("label_in").toString();
String entityIdIn = detail.get("entity_id_in").toString();
Map<String, Object> entityDetailOut = detail(entityIdOut, labelOut, dto.getGraph());
String entityNameOut = entityDetailOut.get("name").toString();
Map<String, Object> entityDetailIn = detail(entityIdIn, labelIn, dto.getGraph());
String entityNameIn = entityDetailIn.get("name").toString();
save(dto.getEdgeId(), relName, labelOut, entityIdOut, entityNameOut, labelIn, entityIdIn, entityNameIn, detail, dto.getGraph());
}
@Override
public void delete(EdgeVectorCrudDto dto) {
AssertParamUtils.isBlank(dto.getEdgeId(), ReturnCodeEnum.PARAM_INVALID, "edgeId can not be null");
delete(dto.getEdgeId(), dto.getGraph());
}
@Override
public void rebuild(EdgeVectorCrudDto dto) {
AssertParamUtils.isBlank(dto.getEdgeId(), ReturnCodeEnum.PARAM_INVALID, "edgeId can not be null");
Map<String, Object> detail = detail(dto.getEdgeId(), dto.getGraph());
if (detail == null) {
delete(dto.getEdgeId(), dto.getGraph());
} else {
String relName = detail.get("rel_name").toString();
String labelOut = detail.get("label_out").toString();
String entityIdOut = detail.get("entity_id_out").toString();
String labelIn = detail.get("label_in").toString();
String entityIdIn = detail.get("entity_id_in").toString();
Map<String, Object> entityDetailOut = detail(entityIdOut, labelOut, dto.getGraph());
String entityNameOut = entityDetailOut.get("name").toString();
Map<String, Object> entityDetailIn = detail(entityIdIn, labelIn, dto.getGraph());
String entityNameIn = entityDetailIn.get("name").toString();
save(dto.getEdgeId(), relName, labelOut, entityIdOut, entityNameOut, labelIn, entityIdIn, entityNameIn, detail, dto.getGraph());
}
}
@Override
public List<GraphVectorInfo> search(VectorSearchDto dto) {
AssertParamUtils.isBlank(dto.getQuery(), ReturnCodeEnum.PARAM_INVALID, "query can not be null");
List<List<BigDecimal>> features = embeddingService.embeddings(Lists.newArrayList(dto.getQuery()));
BoolQuery query = wrapQuery(dto, features.get(0), dto.getSize());
FilterReqDto filterReqDto = new FilterReqDto();
wrapReq(filterReqDto);
filterReqDto.setSize(dto.getSize());
filterReqDto.must(new FieldFilter(FieldFilterTypeEnum.BOOL, query));
return query(filterReqDto);
}
private void save(String entityId, String label, Map<String, Object> detail, String graph) {
String lockKey = GraphLockUtils.getVertexVectorKey(entityId, label);
try {
// 加锁
graphVectorWriteLock.acquireLock(lockKey, 60000L);
// 语料切分
List<String> contentList = vertexContent(label, detail, graph);
if (CollectionUtils.isEmpty(contentList)) {
return;
}
// 向量化
List<GraphVectorInfo> vectorInfoList = embeddings(entityId, label, contentList, graph);
// 删除旧向量
delete(entityId, label, graph);
// 写入向量
write(vectorInfoList);
} finally {
graphVectorWriteLock.releaseLock(lockKey);
}
}
private void save(String edgeId, String relName, String labelOut, String entityIdOut, String entityNameOut, String labelIn, String entityIdIn, String entityNameIn, Map<String, Object> detail, String graph) {
String lockKey = GraphLockUtils.getEdgeVectorKey(edgeId);
try {
graphVectorWriteLock.acquireLock(lockKey, 60000L);
// 语料切分
List<String> contentList = edgeContent(labelOut, entityNameOut, relName, labelIn, entityNameIn, detail, graph);
if (CollectionUtils.isEmpty(contentList)) {
return;
}
// 向量化
List<GraphVectorInfo> vectorInfoList = embeddings(edgeId, relName, labelOut, entityIdOut, labelIn, entityIdIn, contentList, graph);
// 删除旧向量
delete(edgeId, graph);
// 写入向量
write(vectorInfoList);
} finally {
graphVectorWriteLock.releaseLock(lockKey);
}
}
private void delete(String entityId, String label, String graph) {
List<FieldFilter> fieldFilters = Lists.newArrayList();
fieldFilters.add(new FieldFilter("entity_id", FieldFilterTypeEnum.TERM, entityId));
fieldFilters.add(new FieldFilter("label", FieldFilterTypeEnum.TERM, label));
fieldFilters.add(new FieldFilter("graph", FieldFilterTypeEnum.TERM, graph));
delete(fieldFilters);
}
private void delete(String edgeId, String graph) {
List<FieldFilter> fieldFilters = Lists.newArrayList();
fieldFilters.add(new FieldFilter("edge_id", FieldFilterTypeEnum.TERM, edgeId));
fieldFilters.add(new FieldFilter("graph", FieldFilterTypeEnum.TERM, graph));
delete(fieldFilters);
}
@SuppressWarnings("unchecked")
private Map<String, Object> detail(String entityId, String label, String graph) {
VertexDetailDto detailDto = new VertexDetailDto();
detailDto.setAppId("internal");
detailDto.setUid("groot");
detailDto.setLabel(label);
detailDto.setEntityId(entityId);
detailDto.setExcludeFields(Sets.newHashSet("resume", "trail", "trail_info", "summary_list"));
detailDto.setGraph(graph);
return (Map<String, Object>) vertexFetchService.detail(detailDto);
}
private Map<String, Object> detail(String edgeId, String graph) {
EdgeCrudDto edgeCrudDto = new EdgeCrudDto();
edgeCrudDto.setAppId("internal");
edgeCrudDto.setUid("groot");
edgeCrudDto.setEdgeId(edgeId);
edgeCrudDto.setGraph(graph);
EdgeDetailResult result = edgeManageService.detail(edgeCrudDto);
if (result == null || result.getData() == null) {
return null;
}
return result.getData();
}
private void wrapReq(AbstractCommonRequest reqDto) {
reqDto.setProject("groot-vector-data");
reqDto.setIndexGroup("graph");
reqDto.setIndex("flag");
reqDto.setType("default");
}
private List<String> vertexContent(String label, Map<String, Object> detail, String graph) {
// 删除不需要的语料信息
for (String field : UNUSED_VERTEX_FIELD) {
detail.remove(field);
}
List<String> corpusList = Lists.newArrayList();
// 获取指定实体类型所有属性信息
VertexDesc desc = OntologyCacheManager.getInstance().getVertexDesc(graph, label);
Map<String, Pair<String, String>> properties = Maps.newHashMap();
getAllProperties(desc, properties);
// 对诗词内容特殊处理
String name = detail.get("name").toString();
if (detail.get("categories") != null
&& ((List<?>) detail.get("categories")).contains("poetry")) {
String summary = detail.containsKey("summary") ? detail.get("summary").toString() : null;
String content = detail.containsKey("content") ? detail.get("content").toString() : null;
if (StringUtils.isNotEmpty(summary) && StringUtils.isEmpty(content)) {
corpusList.add("诗词《" + name + "》内容是:" + summary);
detail.remove("summary");
} else if (StringUtils.isEmpty(summary) && StringUtils.isNotEmpty(content)) {
corpusList.add("诗词《" + name + "》内容是:" + content);
detail.remove("content");
} else if (StringUtils.isNotEmpty(summary) && StringUtils.isNotEmpty(content)) {
if (StringUtils.equals(summary, content)) {
corpusList.add("诗词《" + name + "》内容是:" + summary);
detail.remove("summary");
detail.remove("content");
} else {
corpusList.add("诗词《" + name + "》内容是:" + summary);
detail.remove("summary");
}
}
}
// 解析
JSONObject info = new JSONObject();
for (Map.Entry<String, Object> entry : detail.entrySet()) {
if (!properties.containsKey(entry.getKey())) {
continue;
}
Object value = entry.getValue();
Pair<String, String> propInfo = properties.get(entry.getKey());
String valueStr = convert(entry.getKey(), value, propInfo.getValue());
if (StringUtils.isNotEmpty(valueStr)) {
info.put(propInfo.getKey(), valueStr);
}
}
// 大模型切分语料
List<String> llmSplitCorpusList = llmSplitContent(info.toJSONString(), "JSON");
if (CollectionUtils.isNotEmpty(llmSplitCorpusList)) {
corpusList.addAll(llmSplitCorpusList);
}
return corpusList;
}
private List<String> edgeContent(String labelOut, String entityNameOut, String relName, String labelIn, String entityNameIn, Map<String, Object> detail, String graph) {
// 删除不需要的语料信息
for (String field : UNUSED_EDGE_FIELD) {
detail.remove(field);
}
List<String> sentenceList = Lists.newArrayList();
// 针对相关和社会关系特殊处理
if (StringUtils.equalsAny(relName, "相关", "社会关系")) {
String description = null;
if (detail.containsKey("description")) {
description = detail.remove("description").toString();
}
String originName = null;
if (detail.containsKey("origin_name")) {
originName = detail.remove("origin_name").toString();
}
String displayRelName = null;
if (detail.containsKey("display_rel_name")) {
displayRelName = detail.remove("display_rel_name").toString();
}
String transRelationName = null;
if (detail.containsKey("biz_info")) {
JSONObject bizInfo = (JSONObject) detail.remove("biz_info");
if (bizInfo.containsKey("default")) {
JSONObject defaultInfo = (JSONObject) bizInfo.get("default");
if (defaultInfo.containsKey("trans_relation_name")) {
transRelationName = defaultInfo.getString("trans_relation_name");
}
}
}
// 若原始关系名、描述、展示关系名、转义关系名为空,则该数据无用
if (StringUtils.isAllEmpty(originName, description, displayRelName, transRelationName)) {
return Lists.newArrayList();
}
boolean hasTransRelation = false;
if (StringUtils.isNotEmpty(transRelationName)) {
sentenceList.add(transRelationName);
hasTransRelation = true;
}
if (StringUtils.isNotEmpty(originName)) {
if (hasTransRelation) {
sentenceList.add("该关系的原始关系名为:" + originName);
} else {
String sentence = ContentToStrUtils.wrapName(entityNameOut, labelOut)
+ "" + ContentToStrUtils.wrapName(entityNameIn, labelIn)
+ originName;
sentenceList.add(sentence);
}
}
if (StringUtils.isNotEmpty(displayRelName)) {
if (hasTransRelation) {
sentenceList.add("该关系的展示关系名为:" + originName);
} else {
String sentence = ContentToStrUtils.wrapName(entityNameOut, labelOut)
+ "" + ContentToStrUtils.wrapName(entityNameIn, labelIn)
+ displayRelName;
sentenceList.add(sentence);
}
}
if (StringUtils.isNotEmpty(description)) {
if (CollectionUtils.isNotEmpty(sentenceList)) {
sentenceList.add("该关系的关系描述为:" + description);
} else {
sentenceList.add(ContentToStrUtils.wrapName(entityNameOut, labelOut) + "" + ContentToStrUtils.wrapName(entityNameIn, labelIn) + "有关系," + description);
}
}
} else {
// 解析文本
boolean contentPair = PAIR_REL_NAME.contains(relName);
String sentence;
if (contentPair) {
sentence = ContentToStrUtils.wrapName(entityNameOut, labelOut)
+ "" + ContentToStrUtils.wrapName(entityNameIn, labelIn)
+ relName;
} else {
sentence = ContentToStrUtils.wrapName(entityNameOut, labelOut)
+ "" + relName + "" +
ContentToStrUtils.wrapName(entityNameIn, labelIn);
}
sentenceList.add(sentence);
}
if (detail.isEmpty()) {
// 大模型切分语料
return llmSplitContent(Joiner.on("").join(sentenceList), "文本");
}
// 获取指定关系类型所有属性信息
RelationDesc relationDesc = OntologyCacheManager.getInstance().getRelationDesc(graph, GraphConstants.RelIndex.EDGE_INDEX_OPEN_RELATION, relName, labelOut, labelIn);
List<PropDesc> propDescList = relationDesc.getPropertyList();
Map<String, Pair<String, String>> properties = Maps.newHashMap();
getAllProperties(propDescList, properties);
for (Map.Entry<String, Object> entry : detail.entrySet()) {
if (!properties.containsKey(entry.getKey())) {
continue;
}
Object value = entry.getValue();
Pair<String, String> propInfo = properties.get(entry.getKey());
String valueStr = convert(entry.getKey(), value, propInfo.getValue());
String propSentence = "该关系的" + propInfo.getKey() + "为:" + valueStr;
sentenceList.add(propSentence);
}
// 大模型切分语料
return llmSplitContent(Joiner.on("").join(sentenceList), "文本");
}
private void getAllProperties(VertexDesc desc, Map<String, Pair<String, String>> properties) {
if (CollectionUtils.isEmpty(desc.getProperties())) {
return;
}
for (PropDesc propDesc : desc.getProperties()) {
String propName = propDesc.getName();
properties.put(propName, Pair.of(propDesc.getDesc(), propDesc.getFieldType()));
}
if (CollectionUtils.isEmpty(desc.getSubClasses())) {
for (SubClassDesc subClassDesc : desc.getSubClasses()) {
getAllProperties(subClassDesc, properties);
}
}
}
private void getAllProperties(SubClassDesc desc, Map<String, Pair<String, String>> properties) {
if (CollectionUtils.isEmpty(desc.getPropertyList())) {
return;
}
for (PropDesc propDesc : desc.getPropertyList()) {
String propName = BooleanUtils.isTrue(desc.getPrefix()) ? desc.getName() + "." + propDesc.getName() : propDesc.getName();
String chnName = desc.getDesc() + "-" + propDesc.getDesc();
properties.put(propName, Pair.of(chnName, propDesc.getFieldType()));
}
if (CollectionUtils.isEmpty(desc.getSubClassList())) {
for (SubClassDesc subClassDesc : desc.getSubClassList()) {
getAllProperties(subClassDesc, properties);
}
}
}
private void getAllProperties(List<PropDesc> propDescList, Map<String, Pair<String, String>> properties) {
if (CollectionUtils.isEmpty(propDescList)) {
return;
}
for (PropDesc propDesc : propDescList) {
String propName = propDesc.getName();
String chnName = propDesc.getDesc();
properties.put(propName, Pair.of(chnName, propDesc.getFieldType()));
}
}
@SuppressWarnings("unchecked")
private String convert(String key, Object value, String fieldType) {
switch (fieldType) {
case "Integer":
case "Long":
case "Double":
case "String":
return String.valueOf(value);
case "HashMap":
if (key.contains("address")) {
Map<String, Object> addrMap = (Map<String, Object>) value;
if (addrMap.containsKey("address")) {
return addrMap.get("address").toString();
}
String addr = "";
if (addrMap.get("province") != null) {
addr += addrMap.get("province").toString();
}
if (addrMap.get("city") != null) {
addr += addrMap.get("city").toString();
}
if (addrMap.get("county") != null) {
addr += addrMap.get("county").toString();
}
return addr;
} else {
return null;
}
default:
return null;
}
}
private List<String> llmSplitContent(String info, String textType) {
String systemPrompt = GraphVectorParamsConfig.getSplitPrompt();
systemPrompt = systemPrompt.replace("{TEXT_TYPE}", textType);
LLMMessage message = new GPTMessage("system", systemPrompt);
String prompt = "JSON文本是<" + info + ">";
Flux<String> response = llmService.streamChat(LLMType.DeepSeek, prompt, Lists.newArrayList(message), new JSONObject(), 60000L);
String result = LLMStreamingUtils.blockGet(response);
return convertContentList(result);
}
private List<String> convertContentList(String result) {
if (StringUtils.isEmpty(result)) {
return Lists.newArrayList();
}
result = NoSolutionUtils.preProcessJSON(result);
result = NoSolutionUtils.preProcessPartialObject(result);
List<String> contentList = Lists.newArrayList();
try {
// 为了保证jsonArray的字段顺序
JSONArray structArray = JSONArray.parseObject(result, JSONArray.class);
for (int i = 0; i < structArray.size(); i++) {
JSONObject row = structArray.getJSONObject(i);
if (row.containsKey("text")) {
contentList.add(row.getString("text"));
}
}
} catch (Exception e) {
log.error("解析llm模版分析结果错误:{}", ExceptionUtils.getStackTrace(e));
}
return contentList;
}
private List<GraphVectorInfo> embeddings(String entityId, String label, List<String> contentList, String graph) {
List<GraphVectorInfo> vectorInfoList = Lists.newArrayList();
List<List<String>> batchList = Lists.partition(contentList, GraphVectorParamsConfig.getGenerateVectorBatchSize());
for (List<String> batch : batchList) {
List<List<BigDecimal>> features = embeddingService.embeddings(batch);
for (int i = 0; i < features.size(); i++) {
VertexVectorInfo vertexVectorInfo = new VertexVectorInfo();
vertexVectorInfo.setEntityId(entityId);
vertexVectorInfo.setLabel(label);
vertexVectorInfo.setContent(batch.get(i));
vertexVectorInfo.setFeature(features.get(i));
vertexVectorInfo.setGraph(graph);
vectorInfoList.add(vertexVectorInfo);
}
}
return vectorInfoList;
}
private List<GraphVectorInfo> embeddings(String edgeId, String relName, String labelOut, String entityIdOut, String labelIn, String entityIdIn, List<String> contentList, String graph) {
List<GraphVectorInfo> vectorInfoList = Lists.newArrayList();
List<List<String>> batchList = Lists.partition(contentList, GraphVectorParamsConfig.getGenerateVectorBatchSize());
for (List<String> batch : batchList) {
List<List<BigDecimal>> features = embeddingService.embeddings(batch);
for (int i = 0; i < features.size(); i++) {
EdgeVectorInfo edgeVectorInfo = new EdgeVectorInfo();
edgeVectorInfo.setEdgeId(edgeId);
edgeVectorInfo.setRelName(relName);
edgeVectorInfo.setLabelOut(labelOut);
edgeVectorInfo.setEntityIdOut(entityIdOut);
edgeVectorInfo.setLabelIn(labelIn);
edgeVectorInfo.setEntityIdIn(entityIdIn);
edgeVectorInfo.setContent(batch.get(i));
edgeVectorInfo.setFeature(features.get(i));
edgeVectorInfo.setGraph(graph);
vectorInfoList.add(edgeVectorInfo);
}
}
return vectorInfoList;
}
private void write(List<GraphVectorInfo> vectorInfoList) {
BulkReqDto bulkReqDto = new BulkReqDto();
wrapReq(bulkReqDto);
Map<String, OperateBulk> operateBulkMap = Maps.newHashMap();
OperateBulk operateBulk = new OperateBulk();
for (GraphVectorInfo vectorInfo : vectorInfoList) {
JSONObject doc = JSONObject.parseObject(JSONObject.toJSONString(vectorInfo));
operateBulk.insertAdd(IDUtils.getUUID(), doc);
}
operateBulkMap.put(bulkReqDto.getIndex(), operateBulk);
bulkReqDto.setOperateBulk(operateBulkMap);
BulkRespDto bulkRes = bulkService.bulk(bulkReqDto);
if (bulkRes == null || !bulkRes.getSucceed()) {
throw new GraphServiceException(ReturnCodeEnum.DB_ERROR, "save es failure");
}
}
private void delete(List<FieldFilter> fieldFilters) {
FilterReqDto filterReqDto = new FilterReqDto();
wrapReq(filterReqDto);
filterReqDto.setFilterFields(fieldFilters);
ItemsRespDto respDto = filterService.filter(filterReqDto);
if (respDto == null) {
return;
}
if (!respDto.getSucceed()) {
throw new QueueServiceException(QueueTypeEnum.MQ, respDto.getMessage());
}
if (respDto.getData() == null) {
return;
}
long predictCount = respDto.getData().getCount();
DeleteByQueryReqDto deleteByQueryReqDto = new DeleteByQueryReqDto();
wrapReq(deleteByQueryReqDto);
deleteByQueryReqDto.setFilterFields(fieldFilters);
deleteByQueryReqDto.setPredictCount(predictCount);
ItemsRespDto deleteByQueryRespDto = filterService.deleteByQuery(deleteByQueryReqDto);
if (!deleteByQueryRespDto.getSucceed()) {
throw new QueueServiceException(QueueTypeEnum.MQ, respDto.getMessage());
}
}
private BoolQuery wrapQuery(VectorSearchDto dto, List<BigDecimal> queryFeature, int size) {
BoolQuery boolQuery = new BoolQuery();
if (StringUtils.isNotBlank(dto.getLabel())) {
boolQuery.filter(new FieldFilter("label", FieldFilterTypeEnum.TERM, dto.getLabel()));
}
if (StringUtils.isNotBlank(dto.getEntityId())) {
boolQuery.filter(new FieldFilter("entity_id", FieldFilterTypeEnum.TERM, dto.getEntityId()));
}
if (StringUtils.isNotBlank(dto.getEdgeId())) {
boolQuery.filter(new FieldFilter("edge_id", FieldFilterTypeEnum.TERM, dto.getEdgeId()));
}
if (StringUtils.isNotBlank(dto.getRelName())) {
boolQuery.filter(new FieldFilter("rel_name", FieldFilterTypeEnum.TERM, dto.getRelName()));
}
if (StringUtils.isNotBlank(dto.getLabelOut())) {
boolQuery.filter(new FieldFilter("label_out", FieldFilterTypeEnum.TERM, dto.getLabelOut()));
}
if (StringUtils.isNotBlank(dto.getEntityIdOut())) {
boolQuery.filter(new FieldFilter("entity_id_out", FieldFilterTypeEnum.TERM, dto.getEntityIdOut()));
}
if (StringUtils.isNotBlank(dto.getLabelIn())) {
boolQuery.filter(new FieldFilter("label_in", FieldFilterTypeEnum.TERM, dto.getLabelIn()));
}
if (StringUtils.isNotBlank(dto.getEntityIdIn())) {
boolQuery.filter(new FieldFilter("entity_id_in", FieldFilterTypeEnum.TERM, dto.getEntityIdIn()));
}
boolQuery.filter(new FieldFilter("graph", FieldFilterTypeEnum.TERM, dto.getGraph()));
// vector query
float[] vectorArray = new float[queryFeature.size()];
for (int val = 0; val < queryFeature.size(); val++) {
vectorArray[val] = queryFeature.get(val).floatValue();
}
KNNVector knnVector = new KNNVector(vectorArray, size);
boolQuery.must(new FieldFilter("feature", FieldFilterTypeEnum.KNN, knnVector));
return boolQuery;
}
private List<GraphVectorInfo> query(FilterReqDto reqDto) {
ItemsRespDto respDto = filterService.filter(reqDto);
if (!respDto.getSucceed()) {
throw new GraphServiceException(ReturnCodeEnum.DB_ERROR, respDto.getMessage());
}
if (respDto.getData() == null) {
throw new GraphServiceException(ReturnCodeEnum.DB_ERROR, respDto.getMessage());
}
return respDto.getData().getItems().stream()
.map(item -> {
JSONObject jsonItem = JSONObject.parseObject(JSON.toJSONString(item));
String vectorDataType = jsonItem.getString("vector_data_type");
if (vectorDataType == null) {
return null;
}
if (vectorDataType.equals("VERTEX")) {
VertexVectorResult result = jsonItem.toJavaObject(VertexVectorResult.class);
Map<String, Object> entity = detail(result.getEntityId(), result.getLabel(), result.getGraph());
if (entity == null) {
return null;
}
result.setEntity(entity);
return result;
} else if (vectorDataType.equals("EDGE")) {
EdgeVectorResult result = jsonItem.toJavaObject(EdgeVectorResult.class);
Map<String, Object> edge = detail(result.getEdgeId(), result.getGraph());
if (edge == null) {
return null;
}
result.setEdge(edge);
Map<String, Object> source = detail(result.getEntityIdOut(), result.getLabelOut(), result.getGraph());
if (source == null) {
return null;
}
result.setSource(source);
Map<String, Object> target = detail(result.getEntityIdIn(), result.getLabelIn(), result.getGraph());
if (target == null) {
return null;
}
result.setTarget(target);
return result;
}
return null;
})
.filter(Objects::nonNull).collect(Collectors.toList());
}
}

View File

@@ -15,6 +15,7 @@ import com.shuwen.data.entity.manage.api.model.graph.modify.VertexMergeResult;
import com.shuwen.data.entity.manage.api.model.graph.modify.VertexUpdateResult;
import com.shuwen.data.entity.manage.api.service.IVertexManageService;
import com.shuwen.data.entity.manage.common.utils.AssertParamUtils;
import com.shuwen.data.entity.manage.service.external.mq.ProducerHandler;
import com.shuwen.data.entity.manage.service.integration.configmap.SwitchConfig;
import com.shuwen.data.entity.manage.service.internal.AuthorityUtils;
import com.shuwen.data.entity.manage.service.internal.ICheckService;
@@ -96,6 +97,9 @@ public class VertexManageService implements IVertexManageService {
@Resource
private IGraphEntityResumeService graphEntityResumeService;
@Resource
private ProducerHandler producerHandler;
@Value("${deploy.mode}")
private String deployMode;
@@ -178,6 +182,10 @@ public class VertexManageService implements IVertexManageService {
} finally {
// graphWriteLock.vertexRelease(lockKey);
}
// 发送MQ
producerHandler.sendVertex(entityId, label, graph);
return RespDto.succeed(result);
}
@@ -238,6 +246,10 @@ public class VertexManageService implements IVertexManageService {
if (CollectionUtils.isNotEmpty(modifyField)) {
result.setUpdateTime(PropUtils.getPropLong(vertex, PROP_UPDATE_TIME));
}
// 发送MQ
producerHandler.sendVertex(dto.getEntityId(), label, graph);
return RespDto.succeed(result);
}
@@ -318,6 +330,10 @@ public class VertexManageService implements IVertexManageService {
}
VertexUpdateResult result = new VertexUpdateResult(dto.getEntityId(), modifyField);
result.setUpdateTime(updateTime);
// 发送MQ
producerHandler.sendVertex(dto.getEntityId(), label, graph);
return RespDto.succeed(result);
}
@@ -376,6 +392,10 @@ public class VertexManageService implements IVertexManageService {
update(graph, label, dto.getSourceId(), source, null, sourceModifyField, new HashSet<>(), ModifyActionEnum.DELETE);
VertexMergeResult result = new VertexMergeResult(dto.getSourceId(), dto.getTargetId(), true);
// 发送MQ
producerHandler.sendVertex(dto.getTargetId(), label, graph);
return RespDto.succeed(result);
}

View File

@@ -29,6 +29,7 @@ public class GraphRecallParamsConfig {
private static final String KEY_COMPLEX_ADDRESS_PROPERTY_LIST = "complex_address_property_list";
private static final String KEY_MAX_CORPUS_LENGTH = "max_corpus_length";
private static final String KEY_FULL_TEXT_THRESHOLD = "full_text_threshold";
private static final String KEY_VECTOR_THRESHOLD = "vector_threshold";
private static final String KEY_CORPUS_TAG_LABEL_WHITELIST = "corpus_tag_label_whitelist";
private static final String KEY_CORPUS_TAG_BLACKLIST = "corpus_tag_blacklist";
private static final String KEY_CORPUS_HAS_IMAGE = "corpus_has_image";
@@ -54,6 +55,8 @@ public class GraphRecallParamsConfig {
private static Float fullTextThreshold;
private static Float vectorThreshold;
private static Set<String> corpusTagLabelWhitelist;
private static Set<String> corpusTagBlacklist;
@@ -96,6 +99,9 @@ public class GraphRecallParamsConfig {
if (configJson.containsKey(KEY_FULL_TEXT_THRESHOLD)) {
fullTextThreshold = configJson.getFloat(KEY_FULL_TEXT_THRESHOLD);
}
if (configJson.containsKey(KEY_VECTOR_THRESHOLD)) {
vectorThreshold = configJson.getFloat(KEY_VECTOR_THRESHOLD);
}
if (configJson.containsKey(KEY_CORPUS_TAG_LABEL_WHITELIST)) {
corpusTagLabelWhitelist = Sets.newHashSet(configJson.getJSONArray(KEY_CORPUS_TAG_LABEL_WHITELIST).toJavaList(String.class));
}
@@ -214,6 +220,13 @@ public class GraphRecallParamsConfig {
return 0.75f;
}
public static Float getVectorThreshold() {
if (vectorThreshold != null) {
return vectorThreshold;
}
return 0.75f;
}
public static boolean inInCorpusTagLabelWhitelist(String label) {
if (CollectionUtils.isEmpty(corpusTagLabelWhitelist)) {
return false;

View File

@@ -0,0 +1,53 @@
package com.shuwen.data.entity.manage.service.integration.configmap;
import com.alibaba.fastjson.JSONObject;
import com.shuwen.ops.shaman.configmap.Config;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 11:07
*/
@Slf4j
public class GraphVectorParamsConfig {
private static final String DATA_ID = "graph_vector_params";
private static final String KEY_SPLIT_PROMPT = "split_prompt";
private static final String KEY_GENERATE_VECTOR_BATCH_SIZE = "generate_vector_batch_size";
static {
Config.addListener(DATA_ID, GraphVectorParamsConfig::analyzeData);
}
@Getter
private static String splitPrompt;
private static int generateVectorBatchSize;
private static synchronized void analyzeData(String configInfo) {
log.info("get llm params config: {}", configInfo);
JSONObject configJson = JSONObject.parseObject(configInfo);
if (configInfo == null) {
return;
}
if (configJson.containsKey(KEY_SPLIT_PROMPT)) {
splitPrompt = configJson.getString(KEY_SPLIT_PROMPT);
}
if (configJson.containsKey(KEY_GENERATE_VECTOR_BATCH_SIZE)) {
generateVectorBatchSize = configJson.getInteger(KEY_GENERATE_VECTOR_BATCH_SIZE);
}
}
public static int getGenerateVectorBatchSize() {
if (generateVectorBatchSize <= 0) {
return generateVectorBatchSize;
} else {
return 5;
}
}
}

View File

@@ -34,4 +34,12 @@ public class GraphLockUtils {
String key = label + ":" + entityId;
return key;
}
public static String getVertexVectorKey(String label, String entityId) {
return "graph:vertex:vector:" + label + ":" + entityId;
}
public static String getEdgeVectorKey(String edgeId) {
return "graph:edge:vector:" + edgeId;
}
}

View File

@@ -0,0 +1,22 @@
package com.shuwen.data.entity.manage.service.lock;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 16:54
*/
public interface IGraphVectorWriteLock {
/**
* 阻塞式加锁
* @param key 需要加锁的Key
* @param expireInMillis 过期时间
*/
void acquireLock(String key, Long expireInMillis);
/**
* 释放锁
* @param key 需要释放锁的Key
*/
void releaseLock(String key);
}

View File

@@ -0,0 +1,34 @@
package com.shuwen.data.entity.manage.service.lock.impl;
import com.shuwen.data.entity.manage.service.lock.IGraphVectorWriteLock;
import org.redisson.api.RLock;
import org.redisson.api.RedissonClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.concurrent.TimeUnit;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 16:55
*/
@Service
public class GraphVectorWriteLock implements IGraphVectorWriteLock {
@Autowired
private RedissonClient redissonClient;
@Override
public void acquireLock(String key, Long expireInMillis) {
RLock lock = redissonClient.getLock(key);
lock.lock(expireInMillis, TimeUnit.MILLISECONDS);
}
@Override
public void releaseLock(String key) {
RLock lock = redissonClient.getLock(key);
lock.unlock();
}
}

View File

@@ -39,7 +39,7 @@ public class OntologyEdgeService implements IOntologyEdgeService {
@Override
public Map<String, PropDesc> getPropDescAll(String graph, InternalBaseEdge edge) {
Boolean checkRelStrict = SwitchConfig.checkRelStrict;
String method = EntityRequestContextUtils.get().getMethod();
String method = EntityRequestContextUtils.get() != null ? EntityRequestContextUtils.get().getMethod() : "detail";
List<PropDesc> propertyList = null;
String edgeIndex = edge.getRelType();
RelationDesc relationDesc = getRelationDefined(graph, edge);

View File

@@ -1,4 +1,4 @@
package com.shuwen.data.entity.manage.service.storage.extend;
package com.shuwen.data.entity.manage.service.storage.extend.impl;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.lang.Pair;
@@ -16,6 +16,7 @@ import com.shuwen.data.entity.manage.service.config.RecallConfig;
import com.shuwen.data.entity.manage.service.external.similarity.EmbeddingService;
import com.shuwen.data.entity.manage.service.ontology.enums.ModifyActionEnum;
import com.shuwen.data.entity.manage.service.storage.composite.impl.GraphVertexStorageService;
import com.shuwen.data.entity.manage.service.storage.extend.IGraphEntityResumeService;
import com.shuwen.search.proxy.api.entity.base.FieldFilter;
import com.shuwen.search.proxy.api.entity.base.FieldPhrase;
import com.shuwen.search.proxy.api.entity.base.FieldText;

View File

@@ -20,6 +20,7 @@ import java.io.Serializable;
public class PropDesc implements Serializable {
private Integer propertyId;
private String name;
private String desc;
private String fieldType;
private Boolean multivalued;
private Boolean index;

View File

@@ -39,6 +39,10 @@ public class DataGraph implements Serializable {
* 内部赋值: 根据请求结果填入
*/
private Boolean existed;
/**
* 是否向量化
*/
private boolean vectorization = false;
public Set<String> getAllLabels() {
if (CollectionUtils.isEmpty(classList)) {

View File

@@ -0,0 +1,37 @@
package com.shuwen.data.entity.manage.service.util;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.stream.Collectors;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 11:14
*/
public class LLMStreamingUtils {
private static final String LLM_STREAMING_TOKEN_KEY = "delta";
public static String blockGet(Flux<String> response) {
List<String> resultList = response.collectList().block();
if (CollectionUtils.isEmpty(resultList)) return null;
return resultList.stream().map(resultItem -> {
try {
JSONObject resultJson = JSONObject.parseObject(resultItem);
String delta = resultJson.getString(LLM_STREAMING_TOKEN_KEY);
if (delta == null) return "";
if (StringUtils.equals(delta, "null")) return "";
return delta;
} catch (Exception e) {
return "";
}
}).collect(Collectors.joining());
}
}

View File

@@ -0,0 +1,83 @@
package com.shuwen.data.entity.manage.service.util;
import com.alibaba.fastjson.JSONArray;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonToken;
import org.apache.commons.lang3.StringUtils;
import java.io.StringReader;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;
/**
* Project: entity-manage
* Description:
* Author: Kenn
* Create: 2025/3/31 11:16
*/
public class NoSolutionUtils {
public static String preProcessJSON(String result) {
if (StringUtils.contains(result, "```json")) {
String[] resultGroup = result.split("```json");
if (resultGroup.length > 1) {
result = resultGroup[1];
}
result = result.replaceAll("```json", "").replaceAll("```", "");
}
return result.trim();
}
public static String preProcessPartialObject(String result) {
try {
if (!StringUtils.startsWith(result, "[")) return result;
JsonReader jsonReader = new JsonReader(new StringReader(result));
JSONArray jsonArray = new JSONArray();
try {
jsonReader.beginArray();
while (jsonReader.hasNext()) {
HashMap<String, Object> jsonObject = new LinkedHashMap<>();
jsonReader.beginObject();
while (jsonReader.hasNext()) {
String name = jsonReader.nextName();
JsonToken token = jsonReader.peek();
if (token == JsonToken.NULL) {
jsonObject.put(name, "");
jsonReader.skipValue();
} else if (token == JsonToken.BEGIN_OBJECT) {
jsonReader.skipValue();
} else if (token == JsonToken.BEGIN_ARRAY) {
jsonReader.beginArray();
Set<String> stringSet = new LinkedHashSet<>();
while (jsonReader.hasNext()) {
JsonToken itemToken = jsonReader.peek();
if (itemToken == JsonToken.STRING) {
String str = jsonReader.nextString();
stringSet.add(str);
}
}
String value = StringUtils.join(stringSet, ",");
jsonObject.put(name, value);
jsonReader.endArray();
}else {
jsonObject.put(name, jsonReader.nextString());
}
}
jsonReader.endObject();
jsonArray.add(jsonObject);
}
} catch (Exception e) {
try {
jsonReader.endArray();
} catch (Exception ignored) {}
}
return jsonArray.toString();
} catch (Exception e) {
return result;
}
}
}

17
pom.xml
View File

@@ -27,7 +27,7 @@
<shuwen.search-proxy.version>1.1.11</shuwen.search-proxy.version>
<shuwen.verdant-utils.version>0.1.12</shuwen.verdant-utils.version>
<shuwen.configmap.version>1.1.1</shuwen.configmap.version>
<shuwen.llm.version>0.0.16-SNAPSHOT</shuwen.llm.version>
<shuwen.llm.version>0.0.21</shuwen.llm.version>
<shuwen.gss.version>1.3.2-SNAPSHOT</shuwen.gss.version>
<shuwen.lexical-segment.version>0.0.4-SNAPSHOT</shuwen.lexical-segment.version>
@@ -35,6 +35,8 @@
<aliyun.ons.version>1.8.0.Final</aliyun.ons.version>
<aliyun.oss.version>3.6.0</aliyun.oss.version>
<rocketmq.version>4.9.8</rocketmq.version>
<dubbo.version>2.7.23</dubbo.version>
<curator.version>4.2.0</curator.version>
<zookeeper.version>3.4.14</zookeeper.version>
@@ -137,6 +139,19 @@
</dependency>
<!-- aliyun end -->
<!-- service end -->
<dependency>
<groupId>org.apache.rocketmq</groupId>
<artifactId>rocketmq-client</artifactId>
<version>${rocketmq.version}</version>
</dependency>
<dependency>
<groupId>org.apache.rocketmq</groupId>
<artifactId>rocketmq-acl</artifactId>
<version>${rocketmq.version}</version>
</dependency>
<!-- service end -->
<!-- dubbo start -->
<dependency>
<groupId>org.apache.dubbo</groupId>