presto sql輸入表、輸入字段、limit、join操作解析

前言

一段時間沒有寫文章了,寫下最近做的事情。目前我們這邊有一個metabase 查詢平臺供運營、分析師、產品等人員使用,我們的查詢都是使用 presto 引擎。並且我們的大數據組件都使用的是 emr 組件,並且涉及到中國、美西、美東、印度、歐洲、西歐等多個區域,表的權限管理就特別困難。所以就需要一個統一的權限管理來維護某些人擁有那些表的權限,避免隱私的數據泄漏。於是我們就需要一款sql解析工具來解析 presto sql 的輸入表。另外還有一點,由於使用的人較多,資源較少,爲了避免長查詢,我們還會對含有 join 操作查詢、 select * 的查詢直接拒絕

sql 解析

第一種方法

presto 本身也是用的 antlr 進行 sql 語法的編輯,如果你clone了presto的源碼,會在 presto-parse 模塊中發現 presto/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 文件,也就是說我們可以通過直接使用該文件生成解析的配置文件1,然後進行 sql 解析
,但是這種方法太過複雜,我嘗試了下放棄了,因爲從語法樹中獲取某些值時比較混亂,容錯較小,還需要再遍歷其兒子、兄弟節點,並且通過節點的 getText 方法獲得節點值。
在這裏插入圖片描述

第二種方法

我們肯定很容易的就想到,presto 源碼肯定也對 sql 進行了解析,何不直接使用 presto 的解析類呢?
功夫不負有心人,我在源碼中發現了 SqlParser 這個類,該類在 presto-parser 模塊中,通過調用 createStatement(String sql) 方法會返回一個Statement 2,後面我們只需要對 Statement 進行遍歷即可

去掉註釋

在 sql執行之前,我們需要進行一些預操作,比如去掉註釋,分號分割多行代碼

   /**
     * 替換sql註釋
     *
     * @param sqlText sql
     * @return 替換後的sl
     */
    protected String replaceNotes(String sqlText) {
        StringBuilder newSql = new StringBuilder();
        String lineBreak = "\n";
        String empty = "";
        String trimLine;
        for (String line : sqlText.split(lineBreak)) {
            trimLine = line.trim();
            if (!trimLine.startsWith("--") && !trimLine.startsWith("download")) {
                //過濾掉行內註釋
                line = line.replaceAll("/\\*.*\\*/", empty);
                if (org.apache.commons.lang3.StringUtils.isNotBlank(line)) {
                    newSql.append(line).append(lineBreak);
                }
            }
        }
        return newSql.toString();
    }

分號分割多段 sql


    /**
     * ;分割多段sql
     *
     * @param sqlText sql
     * @return
     */
    protected ArrayList<String> splitSql(String sqlText) {
        String[] sqlArray = sqlText.split(Constants.SEMICOLON);
        ArrayList<String> newSqlArray = new ArrayList<>(sqlArray.length);
        String command = "";
        int arrayLen = sqlArray.length;
        String oneCmd;
        for (int i = 0; i < arrayLen; i++) {
            oneCmd = sqlArray[i];
            boolean keepSemicolon = (oneCmd.endsWith("'") && i + 1 < arrayLen && sqlArray[i + 1].startsWith("'"))
                    || (oneCmd.endsWith("\"") && i + 1 < arrayLen && sqlArray[i + 1].startsWith("\""));
            if (oneCmd.endsWith("\\")) {
                command += org.apache.commons.lang.StringUtils.chop(oneCmd) + Constants.SEMICOLON;
                continue;
            } else if (keepSemicolon) {
                command += oneCmd + Constants.SEMICOLON;
                continue;
            } else {
                command += oneCmd;
            }
            if (org.apache.commons.lang3.StringUtils.isBlank(command)) {
                continue;
            }
            newSqlArray.add(command);
            command = "";
        }
        return newSqlArray;
    }

sql解析

經過預處理之後,就需要對 sql 進行解析。inputTables、outputTables、tempTables 分別表示輸入表、輸出表、臨時表

 @Override
    protected Tuple3<HashSet<TableInfo>, HashSet<TableInfo>, HashSet<TableInfo>> parseInternal(String sqlText) throws SqlParseException {
        this.inputTables = new HashSet<>();
        this.outputTables = new HashSet<>();
        this.tempTables = new HashSet<>();
        try {
        	//ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL 表示數字以DECIMAL類型解析
            check(new SqlParser().createStatement(sqlText, new ParsingOptions(ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL)));
        } catch (ParsingException e) {
            throw new SqlParseException("parse sql exception:" + e.getMessage(), e);
        }
        return new Tuple3<>(inputTables, outputTables, tempTables);
    }
根節點識別

進入 check 方法進行 Statement 的遍歷

   /**
     * statement 過濾 只識別select 語句
     *
     * @param statement
     * @throws SqlParseException
     */
    private void check(Statement statement) throws SqlParseException {
    	//如果根節點是查詢節點 獲取所有的孩子節點,深度優先搜索遍歷
        if (statement instanceof Query) {
            Query query = (Query) statement;
            List<Node> children = query.getChildren();
            for (Node child : children) {
                checkNode(child);
            }
        } else if (statement instanceof Use) {
            Use use = (Use) statement;
            this.currentDb = use.getSchema().getValue();
        } else if (statement instanceof ShowColumns) {
            ShowColumns show = (ShowColumns) statement;
            String allName = show.getTable().toString().replace("hive.", "");
            inputTables.add(buildTableInfo(allName, OperatorType.READ));
        } else if (statement instanceof ShowTables) {
            ShowTables show = (ShowTables) statement;
            QualifiedName qualifiedName = show.getSchema().orElseThrow(() -> new SqlParseException("unkonw table name or db name" + statement.toString()));
            String allName = qualifiedName.toString().replace("hive.", "");
            if (allName.contains(Constants.POINT)) {
                allName += Constants.POINT + "*";
            }
            inputTables.add(buildTableInfo(allName, OperatorType.READ));

        } else {
            throw new SqlParseException("sorry,only support read statement,unSupport statement:" + statement.getClass().getName());
        }
    }
  • 如果根節點是 Query 查詢節點 獲取所有的孩子節點,深度優先搜索遍歷
  • 如果根節點是 Use 切換數據庫的節點,修改當前的數據庫名稱
  • 如果根節點是ShowColumns 查看錶字段的節點,將該表加入輸入表
  • 如果根節點是ShowTables 查看錶結構的節點,將該表加入輸入表
  • 否則拋出無法解析的異常

子節點遍歷

主要進入 checkNode 方法,進行查詢語句所有孩子節點的遍歷

/**
     * node 節點的遍歷
     *
     * @param node
     */
    private void checkNode(Node node) throws SqlParseException {
    	//查詢子句
        if (node instanceof QuerySpecification) {
            QuerySpecification query = (QuerySpecification) node;
            //如果查詢包含limit語句 直接將limit入棧
            query.getLimit().ifPresent(limit -> limitStack.push(limit));
            //遍歷子節點
            loopNode(query.getChildren());
        } else if (node instanceof TableSubquery) {
            loopNode(node.getChildren());
        } else if (node instanceof AliasedRelation) {
            // 表的別名 需要放到tableAliaMap供別別名的字段解析使用
            AliasedRelation alias = (AliasedRelation) node;
            String value = alias.getAlias().getValue();
            if (alias.getChildren().size() == 1 && alias.getChildren().get(0) instanceof Table) {
                Table table = (Table) alias.getChildren().get(0);
                tableAliaMap.put(value, table.getName().toString());
            } else {
                tempTables.add(buildTableInfo(value, OperatorType.READ));
            }
            loopNode(node.getChildren());
        } else if (node instanceof Query || node instanceof SubqueryExpression
                || node instanceof Union || node instanceof With
                || node instanceof LogicalBinaryExpression || node instanceof InPredicate) {
            loopNode(node.getChildren());

        } else if (node instanceof Join) {
        	//發現join操作  設置hasJoin 爲true
            hasJoin = true;
            loopNode(node.getChildren());
        }
        //基本都是where條件,過濾掉,如果需要,可以調用getColumn解析字段
        else if (node instanceof LikePredicate || node instanceof NotExpression
                || node instanceof IfExpression
                || node instanceof ComparisonExpression || node instanceof GroupBy
                || node instanceof OrderBy || node instanceof Identifier
                || node instanceof InListExpression || node instanceof DereferenceExpression
                || node instanceof IsNotNullPredicate || node instanceof IsNullPredicate
                || node instanceof FunctionCall) {
            print(node.getClass().getName());

        } else if (node instanceof WithQuery) {
        	//with 子句的臨時表 
            WithQuery withQuery = (WithQuery) node;
            tempTables.add(buildTableInfo(withQuery.getName().getValue(), OperatorType.READ));
            loopNode(withQuery.getChildren());
        } else if (node instanceof Table) {
        	//發現table節點 放入輸入表
            Table table = (Table) node;
            inputTables.add(buildTableInfo(table.getName().toString(), OperatorType.READ));
            loopNode(table.getChildren());
        } else if (node instanceof Select) {
        	//發現select 子句,需要調用getColumn方法從selectItems中獲取select的字段
            Select select = (Select) node;
            List<SelectItem> selectItems = select.getSelectItems();
            HashSet<String> columns = new HashSet<>();
            for (SelectItem item : selectItems) {
                if (item instanceof SingleColumn) {
                    columns.add(getColumn(((SingleColumn) item).getExpression()));
                } else if (item instanceof AllColumns) {
                    columns.add(item.toString());
                } else {
                    throw new SqlParseException("unknow column type:" + item.getClass().getName());
                }
            }
            //將字段入棧
            columnsStack.push(columns);

        } else {
            throw new SqlParseException("unknow node type:" + node.getClass().getName());
        }
    }

上面需要注意的是,每次想輸入表、臨時表中添加表時都對應一個 column的集合從 columnsStack 出棧。
後面看從 selectItems 中獲取字段的方法 getColumn.

  /**
     * select 字段表達式中獲取字段
     *
     * @param expression
     * @return
     */
    private String getColumn(Expression expression) throws SqlParseException {
        if (expression instanceof IfExpression) {
            IfExpression ifExpression = (IfExpression) expression;
            List<Expression> list = new ArrayList<>();
            list.add(ifExpression.getCondition());
            list.add(ifExpression.getTrueValue());
            ifExpression.getFalseValue().ifPresent(list::add);
            return getString(list);
        } else if (expression instanceof Identifier) {
            Identifier identifier = (Identifier) expression;
            return identifier.getValue();
        } else if (expression instanceof FunctionCall) {
            FunctionCall call = (FunctionCall) expression;
            StringBuilder columns = new StringBuilder();
            List<Expression> arguments = call.getArguments();
            int size = arguments.size();
            for (int i = 0; i < size; i++) {
                Expression exp = arguments.get(i);
                if (i == 0) {
                    columns.append(getColumn(exp));
                } else {
                    columns.append(getColumn(exp)).append(columnSplit);
                }
            }
            return columns.toString();
        } else if (expression instanceof ComparisonExpression) {
            ComparisonExpression compare = (ComparisonExpression) expression;
            return getString(compare.getLeft(), compare.getRight());
        } else if (expression instanceof Literal || expression instanceof ArithmeticUnaryExpression) {
            return "";
        } else if (expression instanceof Cast) {
            Cast cast = (Cast) expression;
            return getColumn(cast.getExpression());
        } else if (expression instanceof DereferenceExpression) {
            DereferenceExpression reference = (DereferenceExpression) expression;
            return reference.toString();
        } else if (expression instanceof ArithmeticBinaryExpression) {
            ArithmeticBinaryExpression binaryExpression = (ArithmeticBinaryExpression) expression;
            return getString(binaryExpression.getLeft(), binaryExpression.getRight());
        } else if (expression instanceof SearchedCaseExpression) {
            SearchedCaseExpression caseExpression = (SearchedCaseExpression) expression;
            List<Expression> exps = caseExpression.getWhenClauses().stream().map(whenClause -> (Expression) whenClause).collect(Collectors.toList());
            caseExpression.getDefaultValue().ifPresent(exps::add);
            return getString(exps);
        } else if (expression instanceof WhenClause) {
            WhenClause whenClause = (WhenClause) expression;
            return getString(whenClause.getOperand(), whenClause.getResult());
        } else if (expression instanceof LikePredicate) {
            LikePredicate likePredicate = (LikePredicate) expression;
            return likePredicate.getValue().toString();
        } else if (expression instanceof InPredicate) {
            InPredicate predicate = (InPredicate) expression;
            return predicate.getValue().toString();
        } else if (expression instanceof SubscriptExpression) {
            SubscriptExpression subscriptExpression = (SubscriptExpression) expression;
            return getColumn(subscriptExpression.getBase());
        } else if (expression instanceof LogicalBinaryExpression) {
            LogicalBinaryExpression logicExp = (LogicalBinaryExpression) expression;
            return getString(logicExp.getLeft(), logicExp.getRight());
        } else if (expression instanceof IsNullPredicate) {
            IsNullPredicate isNullExp = (IsNullPredicate) expression;
            return getColumn(isNullExp.getValue());
        } else if (expression instanceof IsNotNullPredicate) {
            IsNotNullPredicate notNull = (IsNotNullPredicate) expression;
            return getColumn(notNull.getValue());
        } else if (expression instanceof CoalesceExpression) {
            CoalesceExpression coalesce = (CoalesceExpression) expression;
            return getString(coalesce.getOperands());
        }
        throw new SqlParseException("無法識別的表達式:" + expression.getClass().getName());
        //   return expression.toString();
    }

由於我們 select 的字段可能包含很多種函數,所以需要一一進行解析,就不在細說。

後續

其實我也實現了 spark sql、hive sql 的輸入表、輸出表的解析,代碼放在了github 上 :https://github.com/scxwhite/parseX
分享不易,請不要吝嗇你的star


  1. 生成配置文件的方式可以通過 idea 安裝 antlr 插件,對 sqlBase.g4文件進行配置後生成 java 等語言的解析類 。 ↩︎

  2. Statement 可以理解爲對語法樹 node 節點的一層封裝,方便於我們的解析 ↩︎

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章