feat:新向量方案接入

This commit is contained in:
zhujinkai
2025-03-25 15:20:39 +08:00
parent 3e077e1073
commit ef8f1a20d5
18 changed files with 685 additions and 33 deletions

View File

@@ -0,0 +1,22 @@
package com.shuwen.groot.common.utils;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
/**
 * Static helpers for waiting on groups of {@link CompletableFuture}s.
 *
 * @Author: zhujinkai
 * @Date: 2025/03/25/10:05
 */
public final class FutureUtils {

    /** Utility class; not meant to be instantiated. */
    private FutureUtils() {
    }

    /**
     * Blocks until every future in {@code futures} has completed and returns their results
     * in the same order as the input list.
     *
     * @param futures futures to wait for; {@code null} is treated like an empty list
     * @param <T>     result type of the futures
     * @return a new list holding one result per future, in input order (empty when there is nothing to wait for)
     * @throws java.util.concurrent.CompletionException   if any future completed exceptionally
     * @throws java.util.concurrent.CancellationException if any future was cancelled
     */
    public static <T> List<T> get(List<CompletableFuture<T>> futures) {
        if (futures == null) {
            return new java.util.ArrayList<>();
        }
        // Wait for the whole group first so that a failure surfaces only after all tasks have settled.
        CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0])).join();
        // Every future is complete at this point, so join() below never blocks.
        return futures.stream()
                .map(CompletableFuture::join)
                .collect(Collectors.toList());
    }
}

View File

@@ -110,6 +110,10 @@
<groupId>com.shuwen.mediax</groupId>
<artifactId>ram-client</artifactId>
</dependency>
<dependency>
<groupId>com.shuwen.mid</groupId>
<artifactId>mid-sdk</artifactId>
</dependency>
<!-- xhzy end -->
<!-- aliyun start -->

View File

@@ -1,6 +1,7 @@
package com.shuwen.groot.manager.http;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.TypeReference;
import com.shuwen.groot.common.enums.InternalErrorCode;
import com.shuwen.groot.manager.constant.Constant;
import okhttp3.HttpUrl;
@@ -27,6 +28,18 @@ public abstract class BaseHttpHandler {
abstract OkHttpClient getClient();
/**
 * Convenience overload: POSTs {@code body} to {@code url} with no query parameters and no
 * extra headers (both passed as {@code null}) and parses the response.
 *
 * @param url           target URL
 * @param body          request body
 * @param typeReference target type the response is parsed into
 * @return whatever the five-argument overload returns for the same request
 */
public <T> T doPost(String url, String body, TypeReference<T> typeReference) {
return doPost(url, null, null, body, typeReference);
}
/**
 * POSTs {@code body} to {@code url} and parses the textual response into the type described
 * by {@code typeReference}.
 *
 * @param url           target URL
 * @param queries       query parameters handed to the string-returning overload (may be {@code null})
 * @param headers       request headers handed to the string-returning overload (may be {@code null})
 * @param body          request body
 * @param typeReference target type the response JSON is parsed into
 * @return the parsed response, or {@code null} when the response text is empty
 */
public <T> T doPost(String url, JSONObject queries, JSONObject headers, String body, TypeReference<T> typeReference) {
    String rawResponse = doPost(url, queries, headers, body);
    // An empty response carries nothing to parse; callers receive null in that case.
    return StringUtils.isEmpty(rawResponse) ? null : JSONObject.parseObject(rawResponse, typeReference);
}
public String doPost(String url, JSONObject queries, JSONObject headers, String body) {
String exactUrl = url(url, queries);
if (StringUtils.isEmpty(exactUrl)) {

View File

@@ -0,0 +1,43 @@
package com.shuwen.groot.service.dto.tdc;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
/**
 * Input for a vectorization task: the resources to process plus task-level parameters.
 *
 * @Author: zhujinkai
 * @Date: 2025/03/22/17:37
 */
@Data
@Accessors(chain = true)
public class GVectorParam {
/**
 * Task resource list.
 */
private List<TaskRes> resList;
/**
 * Task parameters.
 */
private GVectorTaskParam taskParam;
/**
 * Task-level parameters carried alongside the resources.
 */
@Data
@Accessors(chain = true)
public static class GVectorTaskParam {
/**
 * Graph.
 */
private String graph;
/**
 * Tenant ID.
 */
private Long tenantId;
/**
 * User ID.
 */
private Long userId;
}
}

View File

@@ -0,0 +1,123 @@
package com.shuwen.groot.service.dto.tdc;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.Data;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotEmpty;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
/**
 * Payload describing a task submission: business line, task types, priority, per-operator
 * parameters, resources and callback settings.
 *
 * @author winner
 * @date 2024/4/2
 */
@Slf4j
@Data
@Accessors(chain = true)
public class TaskParam implements Serializable {
    private static final long serialVersionUID = 2297841548860043876L;

    /**
     * Business line identifier, at most 10 English characters,<br/>
     * e.g. media, magic, wenlv.
     */
    @NotBlank
    private String bizCode;

    /**
     * Task type list; one submission may run a combination of several task types on one resource.<br/>
     * Values correspond to the task types defined by the concrete operator services, e.g. vf, ite, tem, ocr, asr, $r.<br/>
     * $r is an internal special task type marking the task as the root task of a set; it is not used as an input parameter.
     */
    @NotEmpty
    private List<String> taskType;

    /**
     * Task priority, 7 levels:<br/>
     * 0 - emerg, 10 - prior, 20 - high, 30 - std (standard), 40 - low, 50 - lower, 60 - lowest.
     */
    private Integer taskPrior;

    /**
     * Task parameters in JSON form.<br/>
     * Executor parameters that can be passed through to the concrete operator service; the exact handling depends on the executor implementation.<br/>
     * Note: the first level of the object is keyed by task type (task_type) and holds the settings for that task type's operator service.<br/>
     * Format:
     * <pre>
     * {
     *   // ite vector service parameters
     *   "ite": {
     *     "threshold": 0.85,           // fusion threshold
     *     "dim": 512,                  // vector dimension, default 512
     *     "uploadBucket": "bkName",    // bucket the result is re-uploaded to
     *     "uploadPath": "xxx/xxx",     // re-upload path prefix, without leading or trailing slash
     *   },
     *   // frame-capture service parameters
     *   "vf": {
     *     "extend": "jpg",             // file extension of captured frames
     *     "interval": 1000             // frame capture interval
     *   },
     *   // TEM text-to-text search parameters
     *   // TPA: https://xhzy.yuque.com/qbrrmo/project/adlz0bmboz1006i1
     *   // TEM: https://xhzy.yuque.com/qbrrmo/project/bgdpvn4z7obyek2w
     *   "tem": {
     *     "type": "txt",               // data type, default txt; one of txt, asr, ocr
     *     "target_len": 800,           // target chunk size when splitting text, default 800; actual size floats a few dozen characters around it
     *     "margin": 100,               // allowed floating range in characters, default 100
     *     "dim": 1024                  // dimension, default 1024; must be a multiple of 8, at most 1024
     *   }
     * }
     * </pre>
     */
    private Map<String, Object> taskParam;

    /**
     * Pass-through data supplied when the task is submitted.
     */
    private String taskMeta;

    /**
     * Resource type: video, audio, image, text, m3u8.
     */
    @NotBlank
    private String resType;

    /**
     * Resource definitions;<br/>
     * one or several resources are supported.
     */
    @NotEmpty
    private List<TaskRes> resList;

    /**
     * Feedback parameters in JSON form.<br/>
     * E.g. callback address, MQ definition, interface definition; the concrete data depends on the feedback mode of the task type.
     * <pre>
     * {
     *   "type": "mq",
     *   "param": {
     *     "topic": "",
     *     "tag": ""
     *   }
     * }
     * </pre>
     */
    private Map<String, Object> callback;

    /**
     * Serializes the whole {@link #taskParam} map to a JSON string.
     *
     * @return the JSON text produced by {@code JSON.toJSONString} for the current map
     */
    public String getTaskParamString() {
        Map<String, Object> params = this.taskParam;
        return JSON.toJSONString(params);
    }

    /**
     * Serializes {@link #callback} to a JSON string.
     *
     * @return the JSON text, or {@code null} when no callback is set
     */
    public String getCallbackString() {
        if (this.callback == null) {
            return null;
        }
        return JSON.toJSONString(this.callback);
    }

    /**
     * Returns the parameters registered under one first-level key as a JSON string.
     *
     * @param key first-level key of {@link #taskParam}
     * @return the JSON text of the value stored under {@code key}, or {@code null} when the map
     *         is unset or does not contain the key
     */
    public String getTaskParam(String key) {
        boolean present = taskParam != null && taskParam.containsKey(key);
        if (!present) {
            return null;
        }
        return JSONObject.toJSONString(taskParam.get(key));
    }
}

View File

@@ -0,0 +1,54 @@
package com.shuwen.groot.service.dto.tdc;
import cn.hutool.core.util.StrUtil;
import lombok.Data;
import lombok.experimental.Accessors;
import javax.validation.constraints.NotBlank;
/**
 * A single resource attached to a task submission.
 *
 * @author winner
 * @date 2024/4/3
 */
@Data
@Accessors(chain = true)
public class TaskRes {
    /**
     * Resource key: the unique key used by the upstream business; must not be empty.
     */
    @NotBlank
    private String key;

    /**
     * Unique etag of the resource.<br/>
     * Used later to build resHash (etag + file size form the unique hash).<br/>
     * When the etag is empty it is fetched through the object-storage API by default.
     */
    private String etag;

    /**
     * Resource name, for readability when tasks are displayed.
     */
    private String name;

    /**
     * Task description, for readability.
     */
    private String desc;

    /**
     * Resource address, e.g. the file URL of a video or image.
     */
    private String url;

    /**
     * Inline resource data for resource types that carry it, e.g. the text of a text resource.
     * The tem algorithm's text type may be txt/asr/ocr:
     * - for txt, data is plain text
     * - for asr/ocr, the format follows the input of the TPA text pre-processing algorithm: https://xhzy.yuque.com/qbrrmo/project/adlz0bmboz1006i1
     */
    private String data;

    /**
     * Tells whether this resource carries any content.
     *
     * @return {@code true} when at least one of {@code data} or {@code url} is not blank
     */
    public boolean hasRes() {
        return StrUtil.isNotBlank(data) || StrUtil.isNotBlank(url);
    }
}

View File

@@ -0,0 +1,91 @@
package com.shuwen.groot.service.dto.tdc;
import lombok.Data;
import lombok.experimental.Accessors;
import java.io.Serializable;
import java.util.List;
/**
 * Result of a task submission: the root task key, the pass-through data and one entry per task.
 *
 * @author winner
 * @date 2024/4/15
 */
@Data
@Accessors(chain = true)
public class TaskResult implements Serializable {
private static final long serialVersionUID = 4743388689760551714L;
/**
 * Unique identifier of the root task.<br/>
 * For a single (non-composite) task the result list holds exactly one entry, and the root key
 * equals the task key of that first entry.
 */
private String rootKey;
/**
 * Pass-through data of the task.
 */
private String taskMeta;
/**
 * Task result list.
 */
private List<ResultData> data;
/**
 * Result entry of one task.
 */
@Data
@Accessors(chain = true)
public static class ResultData implements Serializable {
private static final long serialVersionUID = -2731052319432409122L;
/**
 * Resource key: the unique key used by the upstream business.
 */
private String resKey;
/**
 * Task identifier, unique key.<br/>
 * <pre>
 * Format:
 * - single task:  240329_100320_{biz}_{SSS}{random}
 * - complex task: 240329x100320_{biz}_{SSS}{random}
 * </pre>
 */
private String taskKey;
/**
 * Remote task key: under VPC it stores the cloud-side task foreign key; under Cloud it stores
 * the private-deployment task foreign key.
 */
private String outKey;
/**
 * Task type.<br/>
 * Corresponds to the task types defined by the concrete operator services, e.g. $r, ite, ocr, asr.<br/>
 * $r is a special task type marking the task as the root task of a set.
 */
private String taskType;
/**
 * Task status, recording the stage the task execution is in. Defined values:<br/>
 * <pre>
 * - init:     state right after the task has been added successfully
 * - finished: execution completed normally
 * - failed:   the task ended because of any exception or unexpected condition, even after retries
 * </pre>
 */
private String taskStatus;
/**
 * Number of task slices; the smallest unit is the atomic processing capability of the operator
 * service. For example, computing the vector of one image in ITE is one slice.<br/>
 * Note: this is unrelated to the executor's batch capability of accepting e.g. 100 images in one
 * submission; that is the executor handling 100 slices in a single call.
 */
private int taskSize;
/**
 * Number of slices executed successfully.
 */
private int sliceSucc;
/**
 * Number of slices that failed.
 */
private int sliceErr;
/**
 * Result data in JSON form.<br/>
 * Currently only the URL of the result file is returned. The result file holds the JSON output
 * of the algorithm task; e.g. the ITE result of a video is a JSONL file with the embeddings of
 * several images.
 * <pre>
 * {
 *   "url": "xxx"
 * }
 * </pre>
 */
private String result;
}
}

View File

@@ -0,0 +1,23 @@
package com.shuwen.groot.service.dto.vector;
import com.shuwen.groot.service.dto.EmbeddingRetrievalInfo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.Accessors;
/**
 * Search request sent to the vector search endpoint: every property of
 * {@link EmbeddingRetrievalInfo} plus a result count.
 *
 * @Author: zhujinkai
 * @Date: 2025/03/24/14:11
 */
@EqualsAndHashCode(callSuper = true)
@Data
@Accessors(chain = true)
public class GVSearchParam extends EmbeddingRetrievalInfo {
/**
 * Count.
 * NOTE(review): presumably the maximum number of hits to return — confirm against the vector service API.
 */
private int size;
}

View File

@@ -0,0 +1,35 @@
package com.shuwen.groot.service.handler;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.TypeReference;
import com.shuwen.groot.api.dto.library.EmbeddingRetrievalResult;
import com.shuwen.groot.manager.http.HttpHandler;
import com.shuwen.groot.service.dto.vector.GVSearchParam;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.util.List;
/**
 * HTTP client for the vector search endpoint configured under {@code spring.restful.groot-vector}.
 *
 * @Author: zhujinkai
 * @Date: 2025/03/22/18:10
 */
@Component
public class GrootVectorHandler {
    /** Path of the embedding search API, appended to the configured base URL. */
    private static final String API_EMBEDDING_SEARCH = "/groot/vector/search";

    @Value("${spring.restful.groot-vector}")
    private String baseUrl;

    @Resource
    private HttpHandler httpHandler;

    /**
     * POSTs the JSON-serialized search parameters to the vector search endpoint.
     *
     * @param param search parameters
     * @return the parsed hits; never {@code null} — an empty list is returned when the HTTP layer
     *         yields no parsable result, so callers can stream the result without a null check
     */
    public List<EmbeddingRetrievalResult> search(GVSearchParam param) {
        String url = baseUrl + API_EMBEDDING_SEARCH;
        List<EmbeddingRetrievalResult> hits = httpHandler.doPost(url, JSON.toJSONString(param),
                new TypeReference<List<EmbeddingRetrievalResult>>() {});
        // doPost returns null for an empty response body; returning that as-is makes
        // stream-based callers (flatMap(List::stream)) fail with a NullPointerException.
        if (hits == null) {
            return new java.util.ArrayList<>();
        }
        return hits;
    }
}

View File

@@ -0,0 +1,66 @@
package com.shuwen.groot.service.handler;
import cn.hutool.core.bean.BeanUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.TypeReference;
import com.shuwen.groot.manager.http.HttpHandler;
import com.shuwen.groot.service.dto.tdc.GVectorParam;
import com.shuwen.groot.service.dto.tdc.TaskParam;
import com.shuwen.groot.service.dto.tdc.TaskResult;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.util.Collections;
/**
 * Handler for TDC (the task dispatch center).
 *
 * @Author: zhujinkai
 * @Date: 2025/03/22/17:20
 */
@Component
public class TdcHandler {
    /**
     * Path of the "add task" API.
     */
    private static final String API_ADD_TASK = "/api/task/add";

    /** Task type used for vectorization; also the first-level key of the task parameters. */
    private static final String GROOT_VEC_LIB = "groot_vec_lib";

    @Value("${spring.restful.tdc}")
    private String tdcBaseUrl;

    @Resource
    private HttpHandler httpHandler;

    /**
     * Submits a task to TDC.
     *
     * @param taskParam task parameters, sent as the JSON request body
     * @return the task result parsed from the response
     */
    public TaskResult addTask(TaskParam taskParam) {
        String endpoint = tdcBaseUrl + API_ADD_TASK;
        String requestBody = JSON.toJSONString(taskParam);
        return httpHandler.doPost(endpoint, null, null, requestBody, new TypeReference<TaskResult>() {});
    }

    /**
     * Submits a vectorization task built from {@code param}.
     * <p>
     * NOTE(review): the result of {@link #addTask(TaskParam)} is dereferenced without a null
     * check — confirm the HTTP layer can never hand back {@code null} here.
     *
     * @param param resources and task-level parameters
     * @return the root key of the created task
     */
    public String grootVector(GVectorParam param) {
        TaskParam taskParam = new TaskParam();
        taskParam.setBizCode("media");
        taskParam.setResList(param.getResList());
        taskParam.setResType("text");

        // The operator parameters are keyed by task type on the first level.
        JSONObject operatorParams = new JSONObject();
        operatorParams.put(GROOT_VEC_LIB, BeanUtil.beanToMap(param.getTaskParam()));
        taskParam.setTaskParam(operatorParams);
        taskParam.setTaskType(Collections.singletonList(GROOT_VEC_LIB));

        TaskResult submitted = addTask(taskParam);
        return submitted.getRootKey();
    }
}

View File

@@ -1,5 +1,6 @@
package com.shuwen.groot.service.handler.search;
import cn.hutool.core.bean.BeanUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
@@ -7,15 +8,19 @@ import com.google.common.collect.Maps;
import com.shuwen.groot.api.dto.library.EmbeddingRetrievalResult;
import com.shuwen.groot.api.dto.library.ExplainAnalysisItem;
import com.shuwen.groot.api.dto.library.LibraryVector;
import com.shuwen.groot.api.dto.request.library.EmbeddingRetrievalLibraryRequest;
import com.shuwen.groot.api.enums.LibraryVectorLevel;
import com.shuwen.groot.api.enums.LibraryVectorType;
import com.shuwen.groot.common.enums.InternalErrorCode;
import com.shuwen.groot.common.exception.DataBankException;
import com.shuwen.groot.common.utils.FutureUtils;
import com.shuwen.groot.manager.constant.Constant;
import com.shuwen.groot.service.context.ChatContext;
import com.shuwen.groot.service.dto.EmbeddingRetrievalInfo;
import com.shuwen.groot.service.dto.vector.GVSearchParam;
import com.shuwen.groot.service.enums.ChatRecallType;
import com.shuwen.groot.service.handler.BaseSearchHandler;
import com.shuwen.groot.service.handler.GrootVectorHandler;
import com.shuwen.groot.service.parser.LibraryTextParser;
import com.shuwen.groot.service.parser.LibraryVectorParser;
import com.shuwen.search.proxy.api.entity.base.BoolQuery;
@@ -85,6 +90,8 @@ public class LibraryVectorSearchHandler extends BaseSearchHandler {
@Resource
private LibraryTextSearchHandler libraryTextSearchHandler;
@Resource
private GrootVectorHandler grootVectorHandler;
public void add(LibraryVector vector, LibraryVectorType type) {
CrudReqDto crudReqDto = crud(vector.getId(), type);
@@ -147,6 +154,25 @@ public class LibraryVectorSearchHandler extends BaseSearchHandler {
return merge(itemsRespDtoList);
}
/**
 * Runs the recall steps of the new vector scheme in parallel on {@code recallExecutor} and
 * merges their hits into one list ordered by {@code EmbeddingRetrievalResult.RecallComparator}.
 *
 * @param request       retrieval request; supplies the size and whether full-text recall is wanted
 * @param retrievalInfo retrieval information copied into the vector search parameters and used
 *                      for the full-text query
 * @return all recalled hits, sorted
 */
public List<EmbeddingRetrievalResult> newSearch(EmbeddingRetrievalLibraryRequest request, EmbeddingRetrievalInfo retrievalInfo) {
    List<CompletableFuture<List<EmbeddingRetrievalResult>>> recallTasks = Lists.newArrayList();

    // Vector recall
    CompletableFuture<List<EmbeddingRetrievalResult>> vectorRecall = CompletableFuture.supplyAsync(() -> {
        GVSearchParam searchParam = BeanUtil.copyProperties(retrievalInfo, GVSearchParam.class);
        searchParam.setSize(request.getSize());
        return grootVectorHandler.search(searchParam);
    }, recallExecutor);
    recallTasks.add(vectorRecall);

    // Full-text recall, only when the request asks for it
    if (request.isNeedBaseLibrary()) {
        CompletableFuture<List<EmbeddingRetrievalResult>> textRecall = CompletableFuture.supplyAsync(
                () -> convert2ERResult(baseLibrarySearch(retrievalInfo, request.getSize())), recallExecutor);
        recallTasks.add(textRecall);
    }

    // Wait for every recall task, then flatten and sort the hits
    List<List<EmbeddingRetrievalResult>> recalled = FutureUtils.get(recallTasks);
    return recalled.stream()
            .flatMap(List::stream)
            .sorted(new EmbeddingRetrievalResult.RecallComparator())
            .collect(Collectors.toList());
}
private CrudReqDto crud(String id, LibraryVectorType type) {
CrudReqDto crudReqDto = new CrudReqDto();
crudReqDto.setProject(VECTOR_INDEX_PROJECT);
@@ -206,25 +232,7 @@ public class LibraryVectorSearchHandler extends BaseSearchHandler {
}
if (needBaseLibrary) {
futureList.add(CompletableFuture.supplyAsync(() -> {
FilterReqDto filterReqDto = new FilterReqDto();
filterReqDto.setRequestId(retrievalInfo.getRequestId());
filterReqDto.setProject(INDEX_PROJECT);
filterReqDto.setIndexGroup(INDEX_GROUP);
filterReqDto.setIndices(Lists.newArrayList("page", "section"));
filterReqDto.setSize(size);
BoolQuery boolQuery = LibraryTextParser.parse(retrievalInfo.getCleanQuestion(), retrievalInfo.getKeywords(), retrievalInfo.getDatasetIdList(), retrievalInfo.getGraph());
filterReqDto.must(new FieldFilter(FieldFilterTypeEnum.BOOL, boolQuery));
Highlight highlight = new Highlight();
highlight.setNumberOfFragments(1);
highlight.setFragmentSize(512);
Highlight.HighlightField highlightField = new Highlight.HighlightField("content");
highlight.setFields(Lists.newArrayList(highlightField));
filterReqDto.setHighlight(highlight);
return Pair.of(ChatRecallType.base, libraryTextSearchHandler.filter(filterReqDto, Constant.INIT_RETRY_TIME));
}, recallExecutor));
futureList.add(CompletableFuture.supplyAsync(() -> Pair.of(ChatRecallType.base, baseLibrarySearch(retrievalInfo, size)), recallExecutor));
}
for (EmbeddingRetrievalInfo.VectorInfo vectorInfo : retrievalInfo.getVectorInfoList()) {
@@ -264,6 +272,31 @@ public class LibraryVectorSearchHandler extends BaseSearchHandler {
return queryResult;
}
/**
 * Runs the full-text query against the "page" and "section" indices, with a single
 * 512-character highlight fragment on the "content" field.
 *
 * @param retrievalInfo supplies the request id and the inputs of the text query
 * @param size          size set on the filter request
 * @return the response of {@code libraryTextSearchHandler.filter}
 */
private ItemsRespDto baseLibrarySearch(EmbeddingRetrievalInfo retrievalInfo, int size) {
    // Highlight settings: one fragment of 512 characters on "content"
    Highlight contentHighlight = new Highlight();
    contentHighlight.setNumberOfFragments(1);
    contentHighlight.setFragmentSize(512);
    contentHighlight.setFields(Lists.newArrayList(new Highlight.HighlightField("content")));

    // Text query built from the cleaned question, keywords, dataset ids and graph
    BoolQuery textQuery = LibraryTextParser.parse(retrievalInfo.getCleanQuestion(), retrievalInfo.getKeywords(),
            retrievalInfo.getDatasetIdList(), retrievalInfo.getGraph());

    FilterReqDto textRequest = new FilterReqDto();
    textRequest.setRequestId(retrievalInfo.getRequestId());
    textRequest.setProject(INDEX_PROJECT);
    textRequest.setIndexGroup(INDEX_GROUP);
    textRequest.setIndices(Lists.newArrayList("page", "section"));
    textRequest.setSize(size);
    textRequest.must(new FieldFilter(FieldFilterTypeEnum.BOOL, textQuery));
    textRequest.setHighlight(contentHighlight);

    return libraryTextSearchHandler.filter(textRequest, Constant.INIT_RETRY_TIME);
}
/**
 * Converts every item of a text-search response into an {@link EmbeddingRetrievalResult}
 * by round-tripping it through JSON.
 *
 * @param itemsRespDto text-search response whose data items are converted
 * @return one converted result per item, in item order
 */
private List<EmbeddingRetrievalResult> convert2ERResult(ItemsRespDto itemsRespDto) {
    List<EmbeddingRetrievalResult> converted = Lists.newArrayList();
    for (Object item : itemsRespDto.getData().getItems()) {
        String itemJson = JSON.toJSONString(item);
        converted.add(JSONObject.parseObject(itemJson, EmbeddingRetrievalResult.class));
    }
    return converted;
}
@SuppressWarnings("unchecked")
private List<EmbeddingRetrievalResult> merge(List<Pair<ChatRecallType, ItemsRespDto>> itemsRespDtoList) {
Map<String, List<Pair<ChatRecallType, Map<String, Object>>>> recallResult = Maps.newHashMap();

View File

@@ -1,5 +1,6 @@
package com.shuwen.groot.service.impl;
import cn.hutool.core.collection.CollUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
@@ -48,16 +49,17 @@ import com.shuwen.groot.service.ILibraryService;
import com.shuwen.groot.service.IMaterialService;
import com.shuwen.groot.service.context.ChatContext;
import com.shuwen.groot.service.dto.EmbeddingRetrievalInfo;
import com.shuwen.groot.service.handler.ChatHandler;
import com.shuwen.groot.service.handler.EmbeddingHandler;
import com.shuwen.groot.service.handler.GraphDataHandler;
import com.shuwen.groot.service.handler.SegmentHandler;
import com.shuwen.groot.service.handler.*;
import com.shuwen.groot.service.handler.search.LibraryTextSearchHandler;
import com.shuwen.groot.service.handler.search.LibraryVectorSearchHandler;
import com.shuwen.groot.service.log.ChatLog;
import com.shuwen.groot.service.log.RecallLog;
import com.shuwen.groot.service.mq.ProducerHandler;
import com.shuwen.groot.service.utils.LibraryOptionUtils;
import com.shuwen.mid.sdk.tools.config.ConfigNotifyValue;
import com.shuwen.mid.sdk.tools.duplicate.DuplicateService;
import com.shuwen.mid.sdk.tools.duplicate.vo.ArticleVO;
import com.shuwen.mid.sdk.tools.duplicate.vo.CompareResultVO;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.EnumUtils;
@@ -73,18 +75,14 @@ import xhzy.algo.engine.Term;
import javax.annotation.Resource;
import java.math.BigDecimal;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
import static com.shuwen.groot.common.base.Preconditions.checkNotEmpty;
import static com.shuwen.groot.common.base.Preconditions.checkNotNull;
import static com.shuwen.mid.sdk.tools.duplicate.vo.SuggestionResult.SIMILAR;
/**
* Project: groot-data-bank
@@ -95,6 +93,9 @@ import static com.shuwen.groot.common.base.Preconditions.checkNotNull;
@Service
public class LibraryServiceImpl implements ILibraryService {
@ConfigNotifyValue("config.new_vector_graph")
private List<String> newVectorGraphs;
@Resource
private LibraryTextSearchHandler libraryTextSearchHandler;
@@ -410,6 +411,11 @@ public class LibraryServiceImpl implements ILibraryService {
check(request);
EmbeddingRetrievalInfo retrievalInfo = embeddingRetrievalInfo(request.getRequestId(), request.getQuestion(), request.getDatasetIdList(), request.getGraph());
// 判断是否走新的向量检索逻辑
if (CollUtil.isNotEmpty(newVectorGraphs) && newVectorGraphs.contains(request.getGraph())) {
return newEmbeddingRetrieval(request, retrievalInfo);
}
// 召回
List<EmbeddingRetrievalResult> results = libraryVectorSearchHandler.search(retrievalInfo, request.isNeedBaseLibrary(), request.getSize());
@@ -448,6 +454,65 @@ public class LibraryServiceImpl implements ILibraryService {
return deduplication(sortedResults);
}
/**
 * Retrieval flow of the new vector scheme: recall, rank, de-duplicate, entity filter, sort and
 * truncate, then make sure the hit returned by {@code maxVectorStoreResult} is part of the answer.
 *
 * @param request       retrieval request (question, size, NER info, request id)
 * @param retrievalInfo retrieval information passed on to the recall step
 * @return the final, de-duplicated result list
 */
public List<EmbeddingRetrievalResult> newEmbeddingRetrieval(EmbeddingRetrievalLibraryRequest request, EmbeddingRetrievalInfo retrievalInfo) {
    // Recall, then rank and log the recalled hits
    List<EmbeddingRetrievalResult> recalled = libraryVectorSearchHandler.newSearch(request, retrievalInfo);
    rank(request.getQuestion(), recalled);
    RecallLog.log(request.getRequestId(), request.getQuestion(), recalled);

    // Remember the hit with the highest vector score before anything is dropped
    EmbeddingRetrievalResult topVectorHit = maxVectorStoreResult(recalled);

    // De-duplication has to happen before truncation, otherwise duplicates eat into the limit
    List<EmbeddingRetrievalResult> uniqueHits = newDeduplication(recalled);

    // Filter by entity information
    List<EmbeddingRetrievalResult> entityFiltered = postFilter(uniqueHits, request.getNer());

    // Sort and keep the requested number of hits
    List<EmbeddingRetrievalResult> topHits = entityFiltered.stream()
            .sorted(new EmbeddingRetrievalResult.SortComparator())
            .limit(request.getSize())
            .collect(Collectors.toList());

    // Nothing to append when there is no highest-vector-score hit
    if (topVectorHit == null) {
        return topHits;
    }

    // Append the highest-vector-score hit when truncation or filtering removed it
    boolean topVectorHitMissing = topHits.stream()
            .noneMatch(hit -> hit.getId().equals(topVectorHit.getId()));
    if (topVectorHitMissing) {
        topHits.add(topVectorHit);
    }

    // De-duplicate once more, since the appended hit may resemble a kept one
    return newDeduplication(topHits);
}
/**
 * Removes near-duplicate results: a result is kept only when the duplicate service does not
 * rate it {@code SIMILAR} to any result kept before it. Input order is preserved.
 *
 * @param originalResults results to de-duplicate
 * @return a new list holding the first occurrence of each distinct result
 */
private List<EmbeddingRetrievalResult> newDeduplication(List<EmbeddingRetrievalResult> originalResults) {
    DuplicateService detector = DuplicateService.get();
    List<EmbeddingRetrievalResult> keptResults = Lists.newArrayList();
    for (EmbeddingRetrievalResult candidate : originalResults) {
        ArticleVO candidateArticle = new ArticleVO().setTitle(candidate.getTitle()).setContent(candidate.getContent());
        // Stops at the first kept result the candidate is similar to
        boolean similarToKept = keptResults.stream().anyMatch(kept -> {
            ArticleVO keptArticle = new ArticleVO().setTitle(kept.getTitle()).setContent(kept.getContent());
            CompareResultVO comparison = detector.duplicate(candidateArticle, keptArticle);
            return SIMILAR.value().equals(comparison.getSuggestion());
        });
        if (!similarToKept) {
            keptResults.add(candidate);
        }
    }
    return keptResults;
}
@Override
public Object agentRetrieval(AgentRetrievalLibraryRequest request) {
check(request);

View File

@@ -1,5 +1,7 @@
package com.shuwen.groot.service.processor;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.amazonaws.services.s3.model.S3Object;
@@ -24,8 +26,11 @@ import com.shuwen.groot.manager.dto.TreeNode;
import com.shuwen.groot.manager.graph.GraphProcessManager;
import com.shuwen.groot.manager.gss.GssHandler;
import com.shuwen.groot.manager.media.MediaHandler;
import com.shuwen.groot.service.dto.tdc.GVectorParam;
import com.shuwen.groot.service.dto.tdc.TaskRes;
import com.shuwen.groot.service.handler.EmbeddingHandler;
import com.shuwen.groot.service.handler.SegmentHandler;
import com.shuwen.groot.service.handler.TdcHandler;
import com.shuwen.groot.service.handler.TextLLMHandler;
import com.shuwen.groot.service.handler.search.LibraryTextSearchHandler;
import com.shuwen.groot.service.handler.search.LibraryVectorSearchHandler;
@@ -101,6 +106,8 @@ public abstract class LibraryTaskProcessor {
@Resource
private DingTalkNotifier dingTalkNotifier;
@Resource
private TdcHandler tdcHandler;
public abstract void process(LibraryImportTask task);
@@ -296,6 +303,53 @@ public abstract class LibraryTaskProcessor {
libraryTextSearchHandler.add(LibraryLevel.dataset, datasetLibrary, "immediate");
}
/**
 * Submits a vectorization task for the dataset to TDC. Does nothing unless the task is both
 * structured and flagged for vectorization, or when no resource content could be built.
 *
 * @param task               import task being processed
 * @param datasetLibrary     dataset library entry, forwarded to the resource builder
 * @param sectionLibraryList sections of the dataset, forwarded to the resource builder
 */
protected void writeNewVector(LibraryImportTask task, DatasetLibrary datasetLibrary, List<SectionLibrary> sectionLibraryList) {
    boolean eligible = task.isStructured() && task.isVectorization();
    if (!eligible) {
        return;
    }

    // Build the task resource; skip the submission when it carries neither data nor url
    TaskRes resource = buildTaskRes(task, datasetLibrary, sectionLibraryList);
    if (!resource.hasRes()) {
        log.info("dataset res is empty, datasetId={}", task.getDatasetId());
        return;
    }

    // Send the vectorization task to TDC
    GVectorParam vectorParam = new GVectorParam();
    vectorParam.setResList(Collections.singletonList(resource));
    vectorParam.setTaskParam(BeanUtil.copyProperties(task, GVectorParam.GVectorTaskParam.class));
    tdcHandler.grootVector(vectorParam);
}
/**
 * Builds the task resource for a dataset. Depending on what is available, the resource carries
 * either inline text ({@code data}) or a signed file address ({@code url}):
 * <ul>
 *   <li>no sections, txt file: the text fetched from the task URL</li>
 *   <li>no sections, other file types: the merged content of the dataset's pages, if any</li>
 *   <li>sections, markdown file: the signed URL, so that TDC can segment the file itself
 *       (markdown needs its heading/outline structure)</li>
 *   <li>sections, other file types: the concatenated section contents</li>
 * </ul>
 *
 * @param task               import task supplying dataset id, name, file type and URL
 * @param datasetLibrary     dataset library entry (not read here)
 * @param sectionLibraryList sections of the dataset; must not be {@code null}
 * @return the resource; both {@code data} and {@code url} stay {@code null} when nothing was found
 */
private TaskRes buildTaskRes(LibraryImportTask task, DatasetLibrary datasetLibrary, List<SectionLibrary> sectionLibraryList) {
    TaskRes taskRes = new TaskRes()
            .setKey(task.getDatasetId())
            .setName(task.getName());
    String data = null;
    String url = null;
    if (sectionLibraryList.isEmpty()) {
        if (task.getFileType() == LibraryFileType.txt) {
            // Plain-text result
            data = get(task.getUrl());
        } else {
            // Paged result: merge the page contents
            List<PageLibrary> pageLibraryList = libraryTextSearchHandler.getPageList(task.getDatasetId());
            if (CollUtil.isNotEmpty(pageLibraryList)) {
                data = mergePageContent(pageLibraryList);
            }
        }
    } else if (task.getFileType() == LibraryFileType.markdown) {
        // Markdown needs its heading/outline structure, so hand the URL to TDC for segmentation
        url = gssHandler.sign(task.getUrl(), null);
    } else {
        // Sectioned result. Sections without content are skipped: Collectors.joining() would
        // otherwise append the literal text "null" for each of them.
        data = sectionLibraryList.stream()
                .map(SectionLibrary::getContent)
                .filter(content -> content != null)
                .collect(Collectors.joining());
    }
    taskRes.setData(data).setUrl(url);
    return taskRes;
}
protected void writeVector(LibraryImportTask task, DatasetLibrary datasetLibrary, List<SectionLibrary> sectionLibraryList) {
// 根据章节信息,获取待生成向量数据
List<WaitingGenerateVectorData> waitingList = processVector(task, datasetLibrary, sectionLibraryList);
@@ -661,6 +715,14 @@ public abstract class LibraryTaskProcessor {
}
/**
 * Merges the content of the given pages and splits the merged text into paragraphs.
 *
 * @param pageLibraryList pages whose content is merged
 * @return the paragraphs produced by {@code splitParagraph}, or a new empty list when the
 *         merged content is empty
 */
private List<WaitingGenerateVectorData> splitPageParagraph(List<PageLibrary> pageLibraryList) {
    String pagesText = mergePageContent(pageLibraryList);
    if (!pagesText.isEmpty()) {
        return splitParagraph(pagesText, null, null);
    }
    return Lists.newArrayList();
}
private String mergePageContent(List<PageLibrary> pageLibraryList) {
pageLibraryList.sort(Comparator.comparing(PageLibrary::getPage));
StringBuilder builder = new StringBuilder();
int curPage = -1;
@@ -678,10 +740,7 @@ public abstract class LibraryTaskProcessor {
}
curPage = library.getPage();
}
if (builder.length() == 0) {
return Lists.newArrayList();
}
return splitParagraph(builder.toString(), null, null);
return builder.toString();
}
private List<WaitingGenerateVectorData> splitParagraph(String data, List<String> sectionIdList, String title) {

View File

@@ -1,5 +1,6 @@
package com.shuwen.groot.service.processor.core;
import cn.hutool.core.collection.CollUtil;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.shuwen.groot.api.dto.library.DatasetLibrary;
@@ -14,6 +15,7 @@ import com.shuwen.groot.dao.entity.LibraryImportTask;
import com.shuwen.groot.manager.collected.CollectedDataManager;
import com.shuwen.groot.service.processor.LibraryIndexContext;
import com.shuwen.groot.service.processor.LibraryTaskProcessor;
import com.shuwen.mid.sdk.tools.config.ConfigNotifyValue;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@@ -36,6 +38,9 @@ import static com.shuwen.groot.common.base.Preconditions.checkNotNull;
@Component
public class ParsedLibraryTaskProcessor extends LibraryTaskProcessor {
@ConfigNotifyValue("config.new_vector_graph")
private List<String> newVectorGraphs;
@Resource
private CollectedDataManager collectedDataManager;
@@ -114,6 +119,11 @@ public class ParsedLibraryTaskProcessor extends LibraryTaskProcessor {
List<SectionLibrary> sectionLibraryList = libraryTextSearchHandler.getSectionList(task.getDatasetId());
log.info("get waiting section library vector, dataset id: {}, size: {}", task.getDatasetId(), sectionLibraryList.size());
writeVector(task, datasetLibrary, sectionLibraryList);
// 判断是否构建新的向量索引
if (CollUtil.isNotEmpty(newVectorGraphs) && newVectorGraphs.contains(task.getGraph())) {
writeNewVector(task, datasetLibrary, sectionLibraryList);
}
}
public Pair<Boolean, JSONObject> checkResultDone(LibraryImportTask task) {

View File

@@ -1,5 +1,6 @@
package com.shuwen.groot;
import com.shuwen.mid.sdk.tools.config.EnableShamanAkskConfig;
import com.shuwen.ops.shaman.configmap.ShamanPropertySourceFactory;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@@ -14,6 +15,7 @@ import org.springframework.scheduling.annotation.EnableScheduling;
*
* @author Kenn
*/
@EnableShamanAkskConfig
@EnableCaching
@EnableScheduling
@SpringBootApplication

View File

@@ -29,6 +29,8 @@ spring.restful.graph-process=http://test.kg.general.process.xinhuazhiyun.com
spring.restful.smart-crop=http://test.groot.smartcrop.shuwen.com
spring.restful.fusion=http://test.kg.general.process.xinhuazhiyun.com
spring.restful.graph-data=http://test.groot-facade.xinhuazhiyun.com
spring.restful.tdc=http://test.tdc.shuwen.com
spring.restful.groot-vector=http://test.groot-vector.shuwen.com
dubbo.registry.address=116.62.230.200:2181,116.62.225.137:2181,116.62.227.195:2181
dubbo.service.search.common.version=2.0

View File

@@ -83,6 +83,7 @@
<caffeine.version>2.9.3</caffeine.version>
<mockito.version>4.8.0</mockito.version>
<mid-sdk.version>1.0.2-SNAPSHOT</mid-sdk.version>
<shuwen.configmap.version>1.1.1</shuwen.configmap.version>
<shuwen.search.version>1.1.12</shuwen.search.version>
<shuwen.lexical-segment.version>0.0.4-SNAPSHOT</shuwen.lexical-segment.version>
@@ -174,6 +175,11 @@
<artifactId>magic-common</artifactId>
<version>1.2.2-RELEASE</version>
</dependency>
<dependency>
<groupId>com.shuwen.mid</groupId>
<artifactId>mid-sdk</artifactId>
<version>${mid-sdk.version}</version>
</dependency>
<!-- xhzy end -->
<!-- aliyun start -->