Calling a Python-trained model from Java



import org.apache.commons.lang3.StringUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;

public class MatchTensor {
    public static String filePath="D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\vocab.txt";
    public static String labelPath="D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\label.txt";
    public static HashMap<String,Integer> hashMap =new HashMap<String,Integer>(){
        {
            put("0",1693);
        }
    };
    public static Boolean is_control(char char_){
        // Tab, newline and carriage return count as whitespace, not control characters.
        // (The original compared a String against a char, which is always false.)
        if (char_ == '\t' || char_ == '\n' || char_ == '\r'){
            return false;
        }
        // Character.CONTROL=15, FORMAT=16, PRIVATE_USE=18, SURROGATE=19 (ENCLOSING_MARK=7 kept from the original)
        int type = Character.getType(char_);
        if (type == 15 || type == 16 || type == 18 || type == 19 || type == 7){
            return true;
        }
        return false;
    }
    public static Boolean is_whitespace(char char_){
        if (char_ == ' ' || char_ == '\t' || char_ == '\n' || char_ == '\r'){
            return true;
        }
        // SPACE_SEPARATOR=12 (Unicode category Zs)
        if (Character.getType(char_) == 12){
            return true;
        }
        return false;
    }
    public static String cleanText(String str){
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < str.length(); i++){
            char c = str.charAt(i);
            // Drop NUL, the Unicode replacement character, and control characters.
            if (c == 0 || c == 0xfffd || is_control(c)){
                continue;
            }
            if (is_whitespace(c)){
                result.append(' ');
            } else {
                result.append(c);
            }
        }
        return result.toString();
    }
    private static String run_strip_accents(String text){
        // Not implemented; its only call site in basicTokenize is commented out.
        return "";
    }
    private static Boolean _is_punctuation(char char_){
        int num = (int) char_;
        // ASCII punctuation ranges.
        if ((num >= 33 && num <= 47) || (num >= 58 && num <= 64) ||
                (num >= 91 && num <= 96) || (num >= 123 && num <= 126)){
            return true;
        }
        // Unicode punctuation: DASH_PUNCTUATION=20 .. OTHER_PUNCTUATION=24,
        // INITIAL_QUOTE_PUNCTUATION=29, FINAL_QUOTE_PUNCTUATION=30.
        int t = Character.getType(char_);
        if ((t >= 20 && t <= 24) || (t >= 29 && t <= 30)){
            return true;
        }
        return false;
    }
    private static ArrayList<String> _run_split_on_punc(String text){
        ArrayList<StringBuilder> output = new ArrayList<>();
        boolean start_new_word = true;
        for (char char_ : text.toCharArray()){
            if (_is_punctuation(char_)){
                // Every punctuation character becomes its own token.
                output.add(new StringBuilder().append(char_));
                start_new_word = true;
            } else {
                if (start_new_word){
                    output.add(new StringBuilder());
                }
                start_new_word = false;
                output.get(output.size() - 1).append(char_);
            }
        }
        ArrayList<String> result = new ArrayList<>();
        for (StringBuilder sb : output){
            result.add(sb.toString());
        }
        return result;
    }
    public static ArrayList<String> whitespace_tokenize(String text){
        ArrayList<String> outtexts = new ArrayList<>();
        if (StringUtils.isNotEmpty(text)){
            // Split on runs of whitespace.
            String[] texts = text.split("\\s+");
            for (String i : texts){
                outtexts.add(i);
            }
        }
        return outtexts;
    }
    public static ArrayList<String> basicTokenize(String text){
        String newText = cleanText(text);
        ArrayList<String> outtexts = whitespace_tokenize(newText);
        ArrayList<String> split_tokens = new ArrayList<>();
        for (String s : outtexts){
            String news = s.toLowerCase();
            //String newtoken = run_strip_accents(news);
            split_tokens.addAll(_run_split_on_punc(news));
        }
        // Re-join and re-split so empty fragments are dropped.
        // (The original substring(0, length-1) crashed on empty input.)
        return whitespace_tokenize(String.join(" ", split_tokens));
    }
    public static HashMap<String,Integer> getVocab(String filePath) throws IOException {
        HashMap<String,Integer> hashMap = new HashMap<>();
        // readLine() already strips the trailing newline.
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(filePath), "UTF-8"))) {
            String line;
            int i = 0;
            while ((line = br.readLine()) != null){
                hashMap.put(line, i);
                i += 1;
            }
        }
        return hashMap;
    }
    public static String index2label(int index) throws IOException {
        // Note: this re-reads label.txt on every lookup.
        HashMap<Integer,String> hashMap = new HashMap<>();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(labelPath), "UTF-8"))) {
            String line;
            int i = 0;
            while ((line = br.readLine()) != null){
                hashMap.put(i, line);
                i += 1;
            }
        }
        return hashMap.get(index);
    }
    public static ArrayList<String> wordpiece_tokenizer(String text) throws IOException{
        // Note: this reloads the vocabulary from disk on every call; see the caching sketch after the code.
        HashMap<String,Integer> hashMap = getVocab(filePath);
        ArrayList<String> output_tokens = new ArrayList<>();
        for (String s : whitespace_tokenize(text)){
            char[] chars = s.toCharArray();
            if (chars.length > 100){
                output_tokens.add("[UNK]");
                continue;
            }
            boolean is_bad = false;
            int start = 0;
            ArrayList<String> sub_tokens = new ArrayList<>();
            while (start < chars.length){
                int end = chars.length;
                String cur_substr = null;
                // Greedy longest-match-first search over the vocabulary.
                while (start < end){
                    String substr = new String(chars, start, end - start);
                    if (start > 0){
                        substr = "##" + substr;
                    }
                    if (hashMap.containsKey(substr)){
                        cur_substr = substr;
                        break;
                    }
                    end -= 1;
                }
                if (cur_substr == null){ // no match at all: the whole word is unknown
                    is_bad = true;       // (the original called isEmpty() on null here)
                    break;
                }
                sub_tokens.add(cur_substr);
                start = end;
            }
            if (is_bad){
                output_tokens.add("[UNK]");
            } else {
                output_tokens.addAll(sub_tokens);
            }
        }
        return output_tokens;
    }
    public static ArrayList<String> tokenize(String text) throws IOException {
        ArrayList<String> split_tokens = new ArrayList<>();
        for (String s : basicTokenize(text)){
            split_tokens.addAll(wordpiece_tokenizer(s));
        }
        return split_tokens;
    }
    public static InputFeature constructTensor(String data) throws IOException {
        HashMap<String,Integer> hashMap = getVocab(filePath);
        InputExample example = new InputExample();
        example.setGuid("1");
        example.setText_a(data); // the original ignored the data parameter
        example.setText_b("");
        example.setLabel("0");
        ArrayList<String> tokens_a = tokenize(example.getText_a());
        ArrayList<String> tokens_anew = new ArrayList<>();
        ArrayList<String> tokens = new ArrayList<>();
        ArrayList<Integer> segment_ids = new ArrayList<>();
        tokens.add("[CLS]");
        segment_ids.add(0);
        // Truncate to max_seq_length - 2 = 23 tokens, leaving room for [CLS] and [SEP].
        if (tokens_a.size() > 23){
            for (int i = 0; i < 23; i++){
                tokens_anew.add(tokens_a.get(i));
            }
        } else {
            tokens_anew = tokens_a;
        }
        for (String s : tokens_anew){
            tokens.add(s);
            segment_ids.add(0);
        }
        tokens.add("[SEP]");
        segment_ids.add(0);
        ArrayList<Integer> input_ids = new ArrayList<>();
        ArrayList<Integer> input_mask = new ArrayList<>();
        for (String s : tokens){
            input_ids.add(hashMap.get(s));
            input_mask.add(1);
        }
        // Zero-pad all three sequences out to max_seq_length = 25.
        while (input_ids.size() < 25){
            input_ids.add(0);
            input_mask.add(0);
            segment_ids.add(0);
        }
        InputFeature feature = new InputFeature();
        feature.setInput_ids(input_ids);
        feature.setInput_mask(input_mask);
        feature.setSegments_ids(segment_ids);
        feature.setLabel_id(1693);
        return feature;
    }
    public static Session readGraph() throws IOException {
        byte[] graphDef = Files.readAllBytes(Paths.get("D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\graph.db"));
        Graph g = new Graph();
        g.importGraphDef(graphDef);
        return new Session(g);
    }
    public static void main(String[] args) throws Exception{
        Date start = new Date();
        Session session = readGraph();
        Date t1 = new Date();
        System.out.println(t1.getTime() - start.getTime());
        System.out.println("..........t1");
        InputFeature input = constructTensor("........");
        Date t2 = new Date();
        System.out.println(t2.getTime() - t1.getTime());
        System.out.println("..........t2");
        // Copy the three feature lists into primitive arrays for Tensor.create().
        ArrayList<Integer> input_ids1 = input.getInput_ids();
        int[] inputs_ids = new int[25];
        for (int i = 0; i < input_ids1.size(); i++){
            inputs_ids[i] = input_ids1.get(i);
        }
        ArrayList<Integer> input_mask1 = input.getInput_mask();
        int[] input_mask = new int[25];
        for (int i = 0; i < input_mask1.size(); i++){
            input_mask[i] = input_mask1.get(i);
        }
        ArrayList<Integer> segments_ids1 = input.getSegments_ids();
        int[] segments_ids = new int[25];
        for (int i = 0; i < segments_ids.length; i++){
            segments_ids[i] = segments_ids1.get(i);
        }
        Date t3 = new Date();
        System.out.println(t3.getTime() - t2.getTime());
        System.out.println("..........t3");
        Tensor result = session.runner()
                .feed("inputs_id", Tensor.create(inputs_ids))
                .feed("input_mask", Tensor.create(input_mask))
                .feed("token_type_ids", Tensor.create(segments_ids))
                .fetch("loss/probabilities")
                .run().get(0);
        Date t4 = new Date();
        System.out.println(t4.getTime() - t3.getTime());
        System.out.println("..........t4");
        // FLOAT tensor with shape [1, 1693]
        float[][] prop = (float[][]) result.copyTo(new float[1][1693]);
        HashMap<Float,Integer> map = new HashMap<>();
        float[] t = prop[0];
        for (int i = 0; i < prop[0].length; i++){
            map.put(prop[0][i], i); // map each probability back to its index (equal values overwrite)
        }
        Arrays.sort(t);
        HashMap<String,List> result1 = new HashMap<>();
        List list = new ArrayList();
        // Take the three highest-probability labels.
        for (int i = t.length - 1; i >= t.length - 3; i--){
            HashMap<String,Object> hashMap = new HashMap<>();
            int label = map.get(t[i]);
            System.out.println(".................................");
            hashMap.put("weight", t[i]);
            System.out.println(t[i]);
            hashMap.put("text", index2label(label));
            System.out.println(label);
            System.out.println(index2label(label));
            list.add(hashMap);
        }
        Date end = new Date();
        long diff = end.getTime() - start.getTime();
        System.out.println(start.getTime());
        System.out.println(diff);
        result1.put("result", list);
        Date t5 = new Date();
        System.out.println(t5.getTime() - t4.getTime());
        System.out.println("..........t5");
    }



}
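
One caveat in main() above: mapping each probability to its index through a HashMap<Float,Integer> silently drops labels whose probabilities happen to be exactly equal. A minimal tie-safe sketch of the same top-3 extraction, sorting indices instead of values (the TopK class and helper name are my additions, not part of the original code):

import java.util.Comparator;
import java.util.stream.IntStream;

public class TopK {
    // Returns the indices of the k largest entries of probs, highest probability first.
    public static int[] topK(float[] probs, int k) {
        return IntStream.range(0, probs.length)
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> -probs[i]))
                .limit(k)
                .mapToInt(Integer::intValue)
                .toArray();
    }
}

With this, the loop in main() could become: for (int idx : TopK.topK(prop[0], 3)) { System.out.println(index2label(idx) + " " + prop[0][idx]); }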







If you only want to see Java calling a Python model, I recommend this post: https://blog.csdn.net/rabbit_judy/article/details/80054085#commentBox

This post mainly records applying BERT's data preprocessing to the input in Java and then calling the Python-trained BERT model. The results are good, but it is somewhat slow.
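
Much of that slowness comes from wordpiece_tokenizer calling getVocab on every token, re-reading vocab.txt from disk each time. A minimal sketch of loading the vocabulary once and reusing it (the VocabCache class is my addition, assuming the same one-token-per-line vocab.txt format that getVocab reads):

import java.io.IOException;
import java.util.HashMap;

public class VocabCache {
    private static HashMap<String, Integer> vocab;

    // Reads vocab.txt on the first call only; later calls reuse the cached map.
    public static synchronized HashMap<String, Integer> get(String path) throws IOException {
        if (vocab == null) {
            vocab = MatchTensor.getVocab(path);
        }
        return vocab;
    }
}

With this in place, wordpiece_tokenizer would call VocabCache.get(filePath) instead of getVocab(filePath); the same idea applies to index2label, which re-reads label.txt on every lookup.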

Python's ord('中') is equivalent to Java's (int) '中'.
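
For example:

public class OrdDemo {
    public static void main(String[] args) {
        // Python: ord('中') -> 20013
        System.out.println((int) '中'); // prints 20013 in Java as well
    }
}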

Python's unicodedata.category('中') corresponds to Java's Character.getType('中') (getType is static, and it returns an int category constant rather than a two-letter category string).
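
For example:

public class CategoryDemo {
    public static void main(String[] args) {
        // Python: unicodedata.category('中') -> 'Lo' (Letter, other)
        // Java returns the numeric constant instead of the two-letter code:
        System.out.println(Character.getType('中'));                           // 5
        System.out.println(Character.getType('中') == Character.OTHER_LETTER); // true
    }
}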
