全民AI的时代已经在趋势之中,各类应用层出不穷,而想要构建一个完善的AI应用/系统,底层存储是不可缺少的一个组件。
与传统数据库或大数据存储不同的是,这种场景下则需要选择向量数据库,是专门用来存储和查询向量的数据库,其存储的向量来自于对文本、语音、图像、视频等的向量化数据,向量数据库不仅能够完成基本的CRUD(添加、读取查询、更新、删除)等操作,还能够对向量数据进行更快速的相似性搜索。
Milvus是众多向量库中的之一,适用于多个场景,如Questions & Answering系统、推荐系统等,单节点 Milvus 可以在秒内完成十亿级的向量搜索,分布式架构亦能满足用户的水平扩展需求。
参考文档:
Milvus官网
为AI而生的数据库:Milvus详解及实战
不论是问答还是推荐,它们对上层暴露的接口仅仅是predict(...)
/search(...)
/query(...)
,模式是相同的,因此可以共用一个基本的Schema,固定基本的字段即可。
public class MilvusMeta {
@Getter
private final RecommenderSchema defaultMetricsSchema;
@Getter
private final QASchema defaultQASchema;
public interface Schema {
String getCollectionName();
CreateCollectionParam getCreateCollectionParam();
CreateIndexParam getCreateIndexParam();
}
@AllArgsConstructor
public abstract static class BasicSchema implements Schema {
public static final String SEARCH_PARAM = "{\"nprobe\":10}"; // Params
public static final String INDEX_PARAM = "{\"nlist\":1024}"; // ExtraParam
public static final IndexType INDEX_TYPE_DEFAULT = IndexType.IVF_FLAT;
public static final MetricType METRIC_TYPE_DEFAULT = MetricType.L2; // metric type
public static final String INDEX_NAME_DEFAULT = "ivf_flat";
public static final String PRIMARY_KEY_FIELD_NAME_DEFAULT = "id";
public static final String PARTITION_KEY_FIELD_NAME_DEFAULT = "public";
public static final String FIELD_NAME_DEFAULT = "embeddings";
@Getter
protected final CreateCollectionParam createCollectionParam;
@Getter
protected final CreateIndexParam createIndexParam;
@Override
public String getCollectionName() {
return createCollectionParam.getCollectionName();
}
}
/**
* This schema is designed for storing Zen metrics.
* TODO: Add more fields/features to describe a metric.
*/
public static class RecommenderSchema extends BasicSchema {
private RecommenderSchema(CreateCollectionParam collectionParam, CreateIndexParam indexParam) {
super(collectionParam, indexParam);
}
public static RecommenderSchema create(ZenAiConfig.Storages.MilvusConf conf) {
ZenAiConfig.Storages.Collection collection = conf.getActiveRecommenderCollection();
return new RecommenderSchema(
defaultCollectionParam(collection, collection.getEmbeddingsDimension()),
MilvusUtil.createIndexParam(collection));
}
private static CreateCollectionParam defaultCollectionParam(ZenAiConfig.Storages.Collection collection,
int dimension) {
FieldType pkType = FieldType.newBuilder()
.withName(collection.getPrimaryKey())
.withDataType(DataType.VarChar)
.withPrimaryKey(true)
.withMaxLength(100)
.withAutoID(false)
.build();
// 被embedding的字段
FieldType embeddedFieldType = FieldType.newBuilder()
.withName(collection.getEmbeddedFieldName())
.withDataType(DataType.VarChar)
.withMaxLength(255)
.build();
// embedding vector字段
FieldType embeddingFieldType = FieldType.newBuilder()
.withName(collection.getFieldName())
.withDataType(DataType.FloatVector)
.withDimension(dimension)
.build();
// 指定分区键字段,每一个Collection都需要指定一个分区键,除了能够Hive/Spark那样切分数据外,还能够加速相似查询。
// 虽然Milvus支持多种方案以切分数据,但从管理复杂度、查询效率上来看,一个Collection对应多个数据分区,是最佳的方案。
FieldType partitionKeyType = FieldType.newBuilder()
.withName(collection.getPartitionKey())
.withPartitionKey(true)
.withDataType(DataType.VarChar)
.withMaxLength(100)
.build();
return CreateCollectionParam.newBuilder()
.withCollectionName(collection.getName())
.withDescription(collection.getDescription())
// .withShardsNum(2)
.addFieldType(pkType)
.addFieldType(embeddedFieldType)
.addFieldType(embeddingFieldType)
.addFieldType(partitionKeyType)
// 开启动态字段添加功能
.withEnableDynamicField(true)
.build();
}
}
public static class QASchema extends BasicSchema {
public static final String ANSWER_FIELD_NAME = "answer";
public static final String SCORE_FIELD_NAME = "score";
public static final float SCORE_MAX_DEFAULT = 5.0f;
public static final float SCORE_MIN_DEFAULT = 0.0f;
public static final String INTENTION_FIELD_NAME = "intention";
public static final String QUESTION_OCCURRENCE = "occurrence";
public QASchema(CreateCollectionParam createCollectionParam, CreateIndexParam createIndexParam) {
super(createCollectionParam, createIndexParam);
}
public static QASchema create(ZenAiConfig.Storages.MilvusConf conf) {
ZenAiConfig.Storages.Collection collection = conf.getActiveQACollection();
return new QASchema(
defaultCollectionParam(collection, collection.getEmbeddingsDimension()),
MilvusUtil.createIndexParam(collection));
}
private static CreateCollectionParam defaultCollectionParam(ZenAiConfig.Storages.Collection collection,
int dimension) {
FieldType pkType = FieldType.newBuilder()
.withName(collection.getPrimaryKey())
.withDataType(DataType.VarChar)
.withPrimaryKey(true)
.withMaxLength(100)
.withAutoID(false)
.build();
FieldType embeddedFieldType = FieldType.newBuilder()
.withName(collection.getEmbeddedFieldName())
.withDataType(DataType.VarChar)
.withMaxLength(65535)
.build();
FieldType embeddingFieldType = FieldType.newBuilder()
.withName(collection.getFieldName())
.withDataType(DataType.FloatVector)
.withDimension(dimension)
.build();
FieldType partitionKeyType = FieldType.newBuilder()
.withName(collection.getPartitionKey())
.withPartitionKey(true)
.withDataType(DataType.VarChar)
.withMaxLength(100)
.build();
return CreateCollectionParam.newBuilder()
.withCollectionName(collection.getName())
.withDescription(collection.getDescription())
// .withShardsNum(2)
.addFieldType(pkType)
.addFieldType(embeddedFieldType)
.addFieldType(embeddingFieldType)
.addFieldType(partitionKeyType)
.withEnableDynamicField(true) // enable to insert new fields without modifying the code
.build();
}
}
public interface IMilvusOperations {
ZenAiConfig.Storages.Collection getCollection();
MilvusConnection.MultiStatus delete(Filter filter);
MilvusConnection.MultiStatus create(MilvusMeta.Index index);
MilvusConnection.MultiStatus drop(String index);
MilvusConnection.MultiStatus insert(MilvusData.Dataset dataset);
MilvusConnection.MultiStatus insertAndFlush(MilvusData.Dataset dataset);
/**
* Query records by filter on the specified partition, which works like a normal SQL engine.
*
* @param partition which partition to query
* @param filter boolean expression obeys the rules of Milvus
* @param outputFields if empty, the result will contain all the fields, including the dynamic;
* otherwise the result only contains the specified fields.
* @return a nonnull instance, size of which is 0 if no matched records, otherwise is positive.
*/
MilvusData.BasicPredictData queryByPartition(String partition, Filter filter, List<String> outputFields);
List<MilvusData.BasicPredictData> search(List<List<Float>> vectors, Filter filter, int topK,
List<String> outputFields);
}
/**
* 每个系统可能有不同的embedding的实现,因此需要定义一个接口。
*/
public interface IEmbedding {
ImmutableList<List<Float>> getEmbeddings(List<String> messages);
}
/**
* 通用接口定义,供应用层使用,可以基于sentence返回Milvus相似性结果集。
*/
public interface INlpSystem extends IMilvusOperations, IDataset, IEmbedding {
default MilvusData.BasicPredictData predict(String sentence) {
return predict(sentence, Filter.TRUE);
}
default MilvusData.BasicPredictData predict(String sentence, Filter filter) {
return predict(sentence, filter, getCollection().getOutputFields());
}
default MilvusData.BasicPredictData predict(String sentence, Filter filter,
List<String> outputFields) {
ImmutableList<List<Float>> vectors = getEmbeddings(Lists.newArrayList(sentence));
if (vectors.isEmpty()) {
return MilvusData.BasicPredictData.EMPTY;
}
List<String> mergedOutputFields = Sets.union(
ImmutableSet.copyOf(outputFields),
ImmutableSet.copyOf(getCollection().getOutputFields()))
.immutableCopy().asList();
List<MilvusData.BasicPredictData> res = search(vectors, filter, getCollection().getTopk(), mergedOutputFields);
return res.isEmpty() ? MilvusData.BasicPredictData.EMPTY : res.get(0);
}
default MilvusData.BasicPredictData predictByPartition(String partition, String sentence) {
return predictByPartition(partition, sentence, Filter.TRUE, getCollection().getOutputFields());
}
default MilvusData.BasicPredictData predictByPartition(String partition, String sentence,
Filter filter, List<String> outputFields) {
ImmutableList<List<Float>> vectors = getEmbeddings(Lists.newArrayList(sentence));
if (vectors.isEmpty()) {
return MilvusData.BasicPredictData.EMPTY;
}
List<String> mergedOutputFields = Sets.union(
ImmutableSet.copyOf(outputFields),
ImmutableSet.copyOf(getCollection().getOutputFields()))
.immutableCopy().asList();
return searchByPartition(partition, vectors.get(0), filter, mergedOutputFields);
}
}
/**
* Q & A系统接口。
*/
public interface IQuestionAnswering extends INlpSystem {
}
/**
* 推荐系统接口。
*/
public interface IRecommender extends INlpSystem, ISyncer {
}
以列式格式
构建插入Milvus的数据集,需要注意的是,Milvus JAVA SDK 2.3.1版本
并不支持列式导致dynamic fields
,因此我对源码进行了改造,以支持列式插入动态字段。
这个问题,已经反馈给了社区,并且已经在v2.3.2版本中支持。
public interface MilvusData {
interface BasicData {
/**
* Return a list view of the splitted data, to avoid copy.
*/
BasicData[] split(int splitSize);
/**
* Return a view of the range [start, end) data, to avoid copy.
*/
BasicData subData(int groupId, int start, int end);
int size();
}
interface EmbeddingsProducer extends Function<List<String>, ImmutableList<List<Float>>> {
}
@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
class Dataset {
private List<BasicInsertData> inserts;
}
abstract class GroupedBasicData implements BasicData {
@Getter
private final int groupId;
@Getter
@Setter
@Accessors(chain = true)
private GroupedBasicData parent;
protected GroupedBasicData(int groupId) {
this.groupId = groupId;
}
/**
* Split the data into more more sub-dataset.
*
* @param groups the number of expected groups
* @return an array of sub-dataset views from the original dataset
*/
public abstract BasicData[] grouped(int groups);
/**
* 每一个切分或是extract的子数据集,都应该拥有一个可以唯一标识它的ID
*/
public String fullGroupId() {
if (parent == null) {
return String.valueOf(groupId);
}
return parent.fullGroupId() + "-" + groupId;
}
}
@Getter
abstract class PartitionedBasicData<T> extends GroupedBasicData {
// 每一个系统都需要指定一个分区键,因此为了能够最小化存储,这里使用一个变量
// 保存整个数据集应该插入
private final T partition;
protected PartitionedBasicData(T partition, int groupId) {
super(groupId);
this.partition = partition;
}
public abstract List<T> getPartitions();
}
/**
* 以列式的形式构建插入数据集,并完成数据导入到Milvus。
* Milvus
*/
class BasicInsertData extends PartitionedBasicData<String> {
private final String collection;
private final ImmutableList<String> ids;
private final ImmutableList<String> embeddingsInput;
private final Supplier<ImmutableList<List<Float>>> vectorsSupplier;
private final EmbeddingsProducer embeddingsProducer;
private ImmutableMap<String, List<?>> dynamicFields;
private final AtomicBoolean vectorsInitialized = new AtomicBoolean(false);
}
public interface MilvusData {
/**
* 一个通用的数据集,可以保存search/query的结果,行式数据结构。
*/
@Getter
class BasicPredictData extends GroupedBasicData {
@Getter
@Builder
@AllArgsConstructor
public static class Row {
@JsonProperty
private String id;
private String embeddingsInput;
@JsonProperty
private Map<String, Object> extensions;
@JsonProperty
private float distance;
@JsonIgnore
private List<Float> vector;
public <T> T getAs(String key, Class<T> clazz) {
return getAs(key, clazz, null);
}
public <T> T getAs(String key, Class<T> clazz, T defaultValue) {
return clazz.cast(extensions.getOrDefault(key, defaultValue));
}
}
}
}
这个代码示例展示了如何构建列式数据集,并将其插入Milvus的流程。
注意到这里特别演示了使用了多线程并行 插入的功能,其原因有二:
@Test
void testSyncCollections() throws ExecutionException, InterruptedException {
ExecutorService executorService = Executors.newFixedThreadPool(2);
// 一组唯一值,用于区别每一条数据记录
ImmutableList<String> ids = ImmutableList.of("1", "2", "3", "4", "5");
// 一组指标名,这些指标就是待检索的合法指标集。
ImmutableList<String> metrics = ImmutableList.of("m1", "m2", "m3", "m4", "m5");
// 生成一组包含5个向量的列表,对应于每一个指标名
ImmutableList<List<Float>> vectors = generateVectors(5, TEST_COLLECTION_DIMENSION);
// 构建插入数据集
MilvusData.BasicInsertData data = new MilvusData.BasicInsertData(
config.getStorages().getMilvusConf().getActiveRecommenderCollection().getName(), ids, metrics,
//这里使用Java中的Provider接口,提供执行插入数据任务时,对指标名列表向量化,由于这里事先生成了向量数组,因此直接从索引构建数据
messages -> messages.stream().map(metrics::indexOf).map(vectors::get).collect(toImmutableList()));
ImmutableList<Double> randoms = ImmutableList.of(1.0d, 2.0d, 3.0d, 4.0d, 5.0d);
// 添加动态字段及相应的数据
data.updateDynamicFields(ImmutableMap.of("random", randoms));
// 构建并行插入数据任务
// 3 groups: ([1, 2]), ([3, 4]), ([5])
// 1 batche: ([1],[2]),([3],[4]),([5])
CompletableFuture<Integer>[] futures = milvusService.syncMetrics(data, 3, 1, executorService);
assertEquals(3, futures.length);
CompletableFuture.allOf(futures).join();
assertEquals(2, futures[0].get());
assertEquals(2, futures[1].get());
assertEquals(1, futures[2].get());
}
用户输入一个指标(Metric)名,或是包含指标名的语句,可以通过Milvus的Search接口,找到最相近的TOP K指标,前提是需要对输入指标名进行向量化,然后以此向量来r从Milvus库中既存的指标集中计算找到最相似的。
@Test
void testLookingForMetric() {
int topK = config..getMilvusConf().getActiveRecommenderCollection().getTopk();
Optional<MilvusData.BasicPredictData> metrics = milvusService.getActiveRecommenderSys().map(system -> system.predict("销售总额"));
assertTrue(metrics.isPresent());
assertEquals(metrics.get().getRows().size(), topK);
metrics = milvusService.getActiveRecommenderSys().map(system -> system.predict("销售总额", lt("random", 0f)));
assertFalse(metrics.isPresent());
}
用户输入一段描述,可以通过Milvus提供的Search接口,找到历史相关的问题,并返回与此问题相关的上下文,并辅助回答AI模型回答用户的当前问题。
@Test
void testSearchWithSimilarityOfMultiVectors() {
ImmutableList<String> testQuestions = ImmutableList.of("用柱形图展示2019年12月的总销售额", "用拆线图展示2020年12月的总净利润");
ImmutableList<String> testIds = testQuestions.stream()
.map(DefaultQuestionAnswering::encodeQuestion)
.collect(ImmutableList.toImmutableList());
IEmbedding embeddingSvc = aiService.getMilvusService().get().getActiveQuestionAnswering().get();
DefaultQuestionAnswering.QAInsertData insertData = system.getInsertDataBuilder()
.ids(testIds)
.questions(testQuestions)
.answers(ImmutableList.of("很好", "不错"))
// 用户对于此问题返回结果的评价
.scores(ImmutableList.of(1.0f, 1.0f))
// 定义embeddings生成器,在插入时才会计算embeddings
.embeddingsProducer(
questions -> {
ImmutableList<List<Float>> qvectors = embeddingSvc.getEmbeddings(questions);
ImmutableList<List<Float>> mvectors = embeddingSvc.getEmbeddings(ImmutableList.of("总销售额", "总净利润"));
return ImmutableList.of(merge(qvectors.get(0), mvectors.get(0)), merge(qvectors.get(1), mvectors.get(1)));
})
.build();
system.insert(new MilvusData.Dataset(ImmutableList.of(insertData)));
// Case 1:
// 用柱形图展示2019年12月的总销售额: 52.0181
// 用拆线图展示2020年12月的总净利润: 147.0664
verifySearch(system, "用拆线图展示2020年12月的总销售额", "总销售额", 0, "销售额", this::merge);
// Case 2:
// 用柱形图展示2019年12月的总销售额: 313.70105
// 用拆线图展示2020年12月的总净利润: 181.3783
verifySearch(system, "今年5月的净利润详情", "净利润", 0, "利润", this::merge);
// Case 3:
// 用柱形图展示2019年12月的总销售额: 160.30568
// 用拆线图展示2020年12月的总净利润: 357.7008
verifySearch(system, "今年5月的销售额详情", "销售额", 0, "销售额", this::merge);
}
Milvus对上层提供了与传统数据库相似的接口,以管理Milvus数据,同时提供了带有过滤功能的数据检索接口,使得上层应用能够很方便地利用传统数据库思维,来设计 和实现自己的系统。
但在使用中也感受到一些局限性或可能提升的点: