讯飞星火大模型通过 WebSocket 方式通信：按协议要求发送请求报文，再将流式返回的多条报文拼接为完整的响应内容，其中 header.status=2 表示最后一条消息。由于 WebSocket 是异步响应的，如果想要同步获取结果，需要使用 CountDownLatch 让线程等待最后一条消息处理完毕后再继续往下执行。星火不同版本之间略有差异，具体以官网提供的 demo 为准。
https://console.xfyun.cn/services/bm3
点应用名称进去查看详情
https://www.xfyun.cn/doc/spark/Web.html
下面仅是一个示例，具体代码请以官网最新文档为准。注意：domain 参数在 1.5、2.0、3.0 版本中分别传 general、generalv2、generalv3，传错会报错 10404。另外，如果想同步返回结果，需要自己使用 CountDownLatch 控制主线程等待，例如 `countDownLatch.await(30000, TimeUnit.MILLISECONDS);`，这样主线程最多等待指定的超时时间。
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.gson.Gson;
import okhttp3.*;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.*;
/**
 * iFlytek Spark chat demo that keeps a multi-turn conversation history.
 *
 * <p>Flow: {@link #main} reads a question from stdin and opens a WebSocket to the signed URL
 * produced by {@link #getAuthUrl}. {@link #onOpen} starts a {@link MyThread}, which sends the
 * question together with {@link #historyList}. {@link #onMessage} prints each streamed chunk and
 * appends it to {@link #totalAnswer}. When a frame with {@code header.status == 2} arrives, the
 * answer is stored in the history and the round ends.
 *
 * <p>Threading: {@code main} polls {@code totalFlag} and {@link MyThread} polls
 * {@code wsCloseFlag}. Both are written from OkHttp callback threads, so both are
 * {@code volatile} to make those writes visible to the polling threads.
 */
public class BigModelNew extends WebSocketListener {
    // Endpoints and credentials.
    //   https://spark-api.xf-yun.com/v1.1/chat -> v1.5 endpoint, domain "general"
    //   https://spark-api.xf-yun.com/v2.1/chat -> v2.0 endpoint, domain "generalv2"
    // NOTE(review): hostUrl below uses "/v1/chat" rather than "/v1.1/chat" from the list above;
    // confirm against the current official documentation.
    public static final String hostUrl = "https://spark-api.xf-yun.com/v1/chat";
    public static final String appid = "1234";
    public static final String apiSecret = "xxxxx";
    public static final String apiKey = "xxxxxxxxxx";
    // v1.5 / v2 / v3 expect "general" / "generalv2" / "generalv3"; a wrong value yields error 10404.
    public static final String domain = "general";
    // Conversation history (user and assistant turns, oldest first) sent with every request.
    public static List<RoleContent> historyList = new ArrayList<>();
    // Answer accumulated from the streamed chunks of the current round.
    public static String totalAnswer = "";
    // Sample prompts: "importance of environmental governance", "environmental protection",
    // "population aging", "I love my motherland".
    // The question of the current round: written by main, read by MyThread.
    public static String NewQuestion = "";
    public static final Gson gson = new Gson();

    // Approximate character budget for the history sent to the service.
    private static final int MAX_HISTORY_CHARS = 12000;
    // How many of the oldest history entries canAddHistory() drops when the budget is exceeded.
    private static final int HISTORY_TRIM_COUNT = 5;

    // Per-connection parameters.
    private String userId;
    // Set when the current round has finished or failed; polled by MyThread.
    private volatile boolean wsCloseFlag;
    // true while main may prompt the user for the next question.
    private static volatile boolean totalFlag = true;

    /**
     * Creates a listener for one WebSocket round.
     *
     * @param userId      caller identifier; not used by the active code, kept so concurrent
     *                    sessions can be told apart
     * @param wsCloseFlag initial value of the "round finished" flag; {@code null} means false
     */
    public BigModelNew(String userId, Boolean wsCloseFlag) {
        this.userId = userId;
        this.wsCloseFlag = Boolean.TRUE.equals(wsCloseFlag);
    }

    /**
     * Interactive loop.
     *
     * <p>Whenever the previous round has finished, this prompts for a question and opens a new
     * WebSocket for it. It returns when stdin reaches end of input.
     */
    public static void main(String[] args) throws Exception {
        // Share one Scanner and one OkHttpClient for the whole session. Each OkHttpClient owns
        // its own connection pool and dispatcher threads, so creating one per question leaks them.
        Scanner scanner = new Scanner(System.in);
        OkHttpClient client = new OkHttpClient.Builder().build();
        try {
            while (true) {
                if (!totalFlag) {
                    Thread.sleep(200);
                    continue;
                }
                System.out.print("我:");
                if (!scanner.hasNextLine()) {
                    break; // stdin exhausted: stop instead of throwing NoSuchElementException
                }
                totalFlag = false;
                NewQuestion = scanner.nextLine();
                // The signature embeds the current date, so rebuild it for every connection.
                String authUrl = getAuthUrl(hostUrl, apiKey, apiSecret);
                String url = authUrl.replace("http://", "ws://").replace("https://", "wss://");
                Request request = new Request.Builder().url(url).build();
                totalAnswer = "";
                // To simulate concurrent users, open several listeners with distinct ids here.
                client.newWebSocket(request, new BigModelNew("0", false));
            }
        } finally {
            client.dispatcher().executorService().shutdown();
            client.connectionPool().evictAll();
        }
    }

    /**
     * Checks whether the accumulated history is still within {@value #MAX_HISTORY_CHARS}
     * characters.
     *
     * <p>If it is not, the {@value #HISTORY_TRIM_COUNT} oldest entries are removed (or every
     * entry, if fewer remain) and false is returned.
     *
     * @return true if the history was within budget and was left untouched
     */
    public static boolean canAddHistory() {
        int historyLength = 0;
        for (RoleContent entry : historyList) {
            if (entry.content != null) {
                historyLength += entry.content.length();
            }
        }
        if (historyLength <= MAX_HISTORY_CHARS) {
            return true;
        }
        // Clear a contiguous prefix. Removing by index one element at a time would shift the
        // remaining elements and skip every other entry.
        historyList.subList(0, Math.min(HISTORY_TRIM_COUNT, historyList.size())).clear();
        return false;
    }

    /**
     * Sends the request on an open WebSocket, then waits for the round to end and closes the
     * socket.
     */
    class MyThread extends Thread {
        private WebSocket webSocket;

        public MyThread(WebSocket webSocket) {
            this.webSocket = webSocket;
        }

        @Override
        public void run() {
            try {
                JSONObject header = new JSONObject();
                header.put("app_id", appid);
                header.put("uid", UUID.randomUUID().toString().substring(0, 10));

                JSONObject chat = new JSONObject();
                chat.put("domain", domain);
                chat.put("temperature", 0.5);
                chat.put("max_tokens", 4096);
                JSONObject parameter = new JSONObject();
                parameter.put("chat", chat);

                // Previous turns first, then the new question (which also joins the history).
                JSONArray text = new JSONArray();
                for (RoleContent previous : historyList) {
                    text.add(JSON.toJSON(previous));
                }
                RoleContent roleContent = new RoleContent();
                roleContent.role = "user";
                roleContent.content = NewQuestion;
                text.add(JSON.toJSON(roleContent));
                historyList.add(roleContent);

                JSONObject message = new JSONObject();
                message.put("text", text);
                JSONObject payload = new JSONObject();
                payload.put("message", message);

                JSONObject requestJson = new JSONObject();
                requestJson.put("header", header);
                requestJson.put("parameter", parameter);
                requestJson.put("payload", payload);
                webSocket.send(requestJson.toString());

                // Wait until onMessage/onFailure marks the round finished, then close.
                while (!wsCloseFlag) {
                    Thread.sleep(200);
                }
                webSocket.close(1000, "");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /** Prints the model prefix and starts the sender thread once the handshake succeeds. */
    @Override
    public void onOpen(WebSocket webSocket, Response response) {
        super.onOpen(webSocket, response);
        System.out.print("大模型:");
        MyThread myThread = new MyThread(webSocket);
        myThread.start();
    }

    /**
     * Handles one streamed frame.
     *
     * <p>On an error code the frame is reported and the round ends. Otherwise the text chunks are
     * printed and accumulated. On the final frame ({@code status == 2}) the full answer is
     * appended to the history.
     */
    @Override
    public void onMessage(WebSocket webSocket, String text) {
        JsonParse myJsonParse = gson.fromJson(text, JsonParse.class);
        if (myJsonParse.header.code != 0) {
            System.out.println("发生错误,错误码为:" + myJsonParse.header.code);
            System.out.println("本次请求的sid为:" + myJsonParse.header.sid);
            webSocket.close(1000, "");
            // Error frames may carry no payload. End the round so MyThread and main stop waiting.
            finishRound();
            return;
        }
        if (myJsonParse.payload != null && myJsonParse.payload.choices != null
                && myJsonParse.payload.choices.text != null) {
            for (Text temp : myJsonParse.payload.choices.text) {
                if (temp.content != null) {
                    System.out.print(temp.content);
                    totalAnswer = totalAnswer + temp.content;
                }
            }
        }
        if (myJsonParse.header.status == 2) {
            // Last frame of the answer.
            System.out.println();
            System.out.println("*************************************************************************************");
            // When over budget, canAddHistory() has already dropped the oldest entries. Dropping
            // one more removes 6 in total, which keeps an alternating user/assistant history
            // aligned on pairs.
            if (!canAddHistory() && !historyList.isEmpty()) {
                historyList.remove(0);
            }
            RoleContent roleContent = new RoleContent();
            roleContent.setRole("assistant");
            roleContent.setContent(totalAnswer);
            historyList.add(roleContent);
            finishRound();
        }
    }

    /** Reports a transport/handshake failure and releases the threads waiting on this round. */
    @Override
    public void onFailure(WebSocket webSocket, Throwable t, Response response) {
        super.onFailure(webSocket, t, response);
        System.out.println("onFailure: " + t);
        try {
            if (null != response) {
                int code = response.code();
                System.out.println("onFailure code:" + code);
                System.out.println("onFailure body:" + (response.body() != null ? response.body().string() : ""));
                if (101 != code) {
                    System.out.println("connection failed");
                    System.exit(1); // non-zero: the handshake failed
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        // Without this, MyThread and the main prompt loop would wait forever.
        finishRound();
    }

    /** Marks the current round finished so MyThread closes the socket and main prompts again. */
    private void finishRound() {
        wsCloseFlag = true;
        totalFlag = true;
    }

    /**
     * Builds the signed request URL.
     *
     * <p>The string {@code "host: ...\ndate: ...\nGET <path> HTTP/1.1"} is signed with
     * HMAC-SHA256 using {@code apiSecret}. The Base64 signature goes into an authorization
     * string, which is itself Base64-encoded and sent as the {@code authorization} query
     * parameter together with {@code date} and {@code host}.
     *
     * @param hostUrl   https URL of the chat endpoint
     * @param apiKey    API key placed in the authorization string
     * @param apiSecret secret used as the HMAC key
     * @return the https URL with authorization, date and host query parameters
     * @throws Exception if the URL is malformed or HMAC-SHA256 is unavailable
     */
    public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
        URL url = new URL(hostUrl);
        // RFC 1123-style date in GMT, e.g. "Mon, 01 Jan 2024 00:00:00 GMT".
        SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
        format.setTimeZone(TimeZone.getTimeZone("GMT"));
        String date = format.format(new Date());
        String preStr = "host: " + url.getHost() + "\n" +
                "date: " + date + "\n" +
                "GET " + url.getPath() + " HTTP/1.1";
        Mac mac = Mac.getInstance("HmacSHA256");
        SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "HmacSHA256");
        mac.init(spec);
        byte[] signatureBytes = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
        String sha = Base64.getEncoder().encodeToString(signatureBytes);
        String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
        HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder()
                .addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8)))
                .addQueryParameter("date", date)
                .addQueryParameter("host", url.getHost())
                .build();
        return httpUrl.toString();
    }

    // Response JSON mapping (deserialized by Gson). Static: they need no enclosing instance.
    static class JsonParse {
        Header header;
        Payload payload;
    }

    static class Header {
        int code;
        int status;
        String sid;
    }

    static class Payload {
        Choices choices;
    }

    static class Choices {
        List<Text> text;
    }

    static class Text {
        String role;
        String content;
    }

    /** One conversation turn; serialized through its getters by fastjson. */
    static class RoleContent {
        String role;
        String content;

        public String getRole() {
            return role;
        }

        public void setRole(String role) {
            this.role = role;
        }

        public String getContent() {
            return content;
        }

        public void setContent(String content) {
            this.content = content;
        }
    }
}
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.example</groupId>
    <artifactId>big_model</artifactId>
    <version>1.0-SNAPSHOT</version>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <!-- Pin the version; an unpinned plugin makes builds depend on the Maven install. -->
                <version>3.11.0</version>
                <configuration>
                    <source>8</source>
                    <target>8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
    <properties>
        <java.version>1.8</java.version>
        <!-- The sources contain Chinese string literals; do not rely on the platform encoding. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <!-- https://mvnrepository.com/artifact/com.alibaba/fastjson -->
        <!-- 1.2.83: earlier 1.2.x releases have known autoType deserialization vulnerabilities. -->
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.83</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/com.google.code.gson/gson -->
        <!-- 2.8.9: fixes CVE-2022-25647 present in earlier versions. -->
        <dependency>
            <groupId>com.google.code.gson</groupId>
            <artifactId>gson</artifactId>
            <version>2.8.9</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.java-websocket/Java-WebSocket -->
        <!-- NOTE(review): the sample code uses OkHttp's WebSocket; this library looks unused. -->
        <dependency>
            <groupId>org.java-websocket</groupId>
            <artifactId>Java-WebSocket</artifactId>
            <version>1.3.8</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/com.squareup.okhttp3/okhttp -->
        <!-- okhttp 4.10.0 brings its own Okio (okio-jvm 3.x). An explicit com.squareup.okio:okio:2.10.0
             pin would put a second, older copy of the okio.* classes on the classpath, so it is
             intentionally not declared here. -->
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
            <version>4.10.0</version>
        </dependency>
        <!-- Lombok is an annotation processor: pin a concrete version (not RELEASE) and keep it off
             the runtime classpath. -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.30</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
            <version>1.7.25</version>
        </dependency>
    </dependencies>
</project>
# Global logging configuration
# Root logger: level DEBUG, single appender named "stdout".
# NOTE(review): DEBUG at the root is verbose; consider INFO outside of debugging.
log4j.rootLogger=DEBUG, stdout
# Console output...
# The "stdout" appender writes to the console using a pattern layout.
# Pattern: %5p = level padded to width 5, [%t] = thread name, %m = message, %n = line separator.
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%5p [%t] - %m%n