Hive UDAF开发详解

说明

这篇文章是来自 Hadoop Hive UDAF Tutorial - Extending Hive with Aggregation Functions:的不严格翻译,由于翻译的文章示例写得比较通俗易懂,此外,我把本身对于Hive的UDAF理解穿插到文章里面。

udfa是hive中用户自定义的汇集函数,hive内置UDAF函数包括有sum()与count(),UDAF实现有简单与通用两种方式,简单UDAF由于使用Java反射致使性能损失,并且有些特性不能使用,已经被弃用了;在这篇博文中咱们将关注Hive中自定义聚类函数-GenericUDAF,UDAF开发主要涉及到如下两个抽象类:html

org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator

源码连接

博文中的全部的代码和数据能够在如下连接找到: hive examples

示例数据准备

首先先建立一张包含示例数据的表:people,该表只有name一列,该列中包含了一个或多个名字,该表数据保存在people.txt文件中。java

~$ cat ./people.txt

John Smith
John and Ann White
Ted Green
Dorothy

把该文件上载到hdfs目录/user/matthew/people中:git

hadoop fs -mkdir people
hadoop fs -put ./people.txt people

下面要建立hive外部表,在hive shell中执行github

CREATE EXTERNAL TABLE people (name string)
ROW FORMAT DELIMITED FIELDS 
	TERMINATED BY '\t' 
	ESCAPED BY '' 
	LINES TERMINATED BY '\n'
STORED AS TEXTFILE 
LOCATION '/user/matthew/people';

相关抽象类介绍

建立一个GenericUDAF必须先了解如下两个抽象类:
org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator

为了更好理解上述抽象类的API,要记住hive只是mapreduce函数,只不过hive已经帮助咱们写好并隐藏mapreduce,向上提供简洁的sql函数,因此咱们要结合Mapper、Combiner与Reducer来帮助咱们理解这个函数。要记住在hadoop集群中有若干台机器,在不一样的机器上Mapper与Reducer任务独立运行。sql

因此大致上来讲,这个UDAF函数读取数据(mapper),汇集一堆mapper输出到部分汇集结果(combiner),而且最终建立一个最终的汇集结果(reducer)。由于咱们跨域多个combiner进行汇集,因此咱们须要保存部分汇集结果。shell

AbstractGenericUDAFResolver

Resolver很简单,要覆盖实现下面方法,该方法会根据sql传人的参数数据格式指定调用哪一个Evaluator进行处理。apache

<span style="background-color: rgb(255, 255, 255);"><span style="font-size:14px;">public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException;</span></span>

GenericUDAFEvaluator

UDAF逻辑处理主要发生在Evaluator中,要实现该抽象类的几个方法。api

在理解Evaluator以前,必须先理解objectInspector接口与GenericUDAFEvaluator中的内部类Model。跨域

ObjectInspector

做用主要是解耦数据使用与数据格式,使得数据流在输入输出端切换不一样的输入输出格式,不一样的Operator上使用不一样的格式。能够参考这两篇文章:first post on Hive UDFsHive中ObjectInspector的做用,里面有关于objectinspector的介绍。

Model

Model表明了UDAF在mapreduce的各个阶段。app

public static enum Mode {
    /**
     * PARTIAL1: 这个是mapreduce的map阶段:从原始数据到部分数据聚合
     * 将会调用iterate()和terminatePartial()
     */
    PARTIAL1,
        /**
     * PARTIAL2: 这个是mapreduce的map端的Combiner阶段,负责在map端合并map的数据::从部分数据聚合到部分数据聚合:
     * 将会调用merge() 和 terminatePartial() 
     */
    PARTIAL2,
        /**
     * FINAL: mapreduce的reduce阶段:从部分数据的聚合到彻底聚合 
     * 将会调用merge()和terminate()
     */
    FINAL,
        /**
     * COMPLETE: 若是出现了这个阶段,表示mapreduce只有map,没有reduce,因此map端就直接出结果了:从原始数据直接到彻底聚合
      * 将会调用 iterate()和terminate()
     */
    COMPLETE
  };

通常状况下,完整的UDAF逻辑是一个mapreduce过程,若是有mapper和reducer,就会经历PARTIAL1(mapper),FINAL(reducer),若是还有combiner,那就会经历PARTIAL1(mapper),PARTIAL2(combiner),FINAL(reducer)。

而有一些状况下的mapreduce,只有mapper,而没有reducer,因此就会只有COMPLETE阶段,这个阶段直接输入原始数据,出结果。

GenericUDAFEvaluator的方法

// 肯定各个阶段输入输出参数的数据格式ObjectInspectors
public  ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException;

// 保存数据汇集结果的类
abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;

// 重置汇集结果
public void reset(AggregationBuffer agg) throws HiveException;

// map阶段,迭代处理输入sql传过来的列数据
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;

// map与combiner结束返回结果,获得部分数据汇集结果
public Object terminatePartial(AggregationBuffer agg) throws HiveException;

// combiner合并map返回的结果,还有reducer合并mapper或combiner返回的结果。
public void merge(AggregationBuffer agg, Object partial) throws HiveException;

// reducer阶段,输出最终结果
public Object terminate(AggregationBuffer agg) throws HiveException;

图解Model与Evaluator关系


Model各阶段对应Evaluator方法调用



Evaluator各个阶段下处理mapreduce流程

实例

下面将讲述一个汇集函数UDAF的实例,咱们将计算people这张表中的name列字母的个数。

下面的函数代码是计算指定列中字符的总数(包括空格)

代码

@Description(name = "letters", value = "_FUNC_(expr) - 返回该列中全部字符串的字符总数")
public class TotalNumOfLettersGenericUDAF extends AbstractGenericUDAFResolver {

    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Exactly one argument is expected.");
        }
        
        ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
        
        if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE){
            throw new UDFArgumentTypeException(0,
                            "Argument must be PRIMITIVE, but "
                            + oi.getCategory().name()
                            + " was passed.");
        }
        
        PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) oi;
        
        if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING){
            throw new UDFArgumentTypeException(0,
                            "Argument must be String, but "
                            + inputOI.getPrimitiveCategory().name()
                            + " was passed.");
        }
        
        return new TotalNumOfLettersEvaluator();
    }

    public static class TotalNumOfLettersEvaluator extends GenericUDAFEvaluator {

        PrimitiveObjectInspector inputOI;
        ObjectInspector outputOI;
        PrimitiveObjectInspector integerOI;
        
        int total = 0;

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
        	
            assert (parameters.length == 1);
            super.init(m, parameters);
           
             //map阶段读取sql列,输入为String基础数据格式
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
			//其他阶段,输入为Integer基础数据格式
            	integerOI = (PrimitiveObjectInspector) parameters[0];
            }

             // 指定各个阶段输出数据格式都为Integer类型
            outputOI = ObjectInspectorFactory.getReflectionObjectInspector(Integer.class,
                    ObjectInspectorOptions.JAVA);
            return outputOI;

        }

        /**
         * 存储当前字符总数的类
         */
        static class LetterSumAgg implements AggregationBuffer {
            int sum = 0;
            void add(int num){
            	sum += num;
            }
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            LetterSumAgg result = new LetterSumAgg();
            return result;
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
        	LetterSumAgg myagg = new LetterSumAgg();
        }
        
        private boolean warned = false;

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            if (parameters[0] != null) {
                LetterSumAgg myagg = (LetterSumAgg) agg;
                Object p1 = ((PrimitiveObjectInspector) inputOI).getPrimitiveJavaObject(parameters[0]);
                myagg.add(String.valueOf(p1).length());
            }
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            LetterSumAgg myagg = (LetterSumAgg) agg;
            total += myagg.sum;
            return total;
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial != null) {
                
                LetterSumAgg myagg1 = (LetterSumAgg) agg;
                
                Integer partialSum = (Integer) integerOI.getPrimitiveJavaObject(partial);
                
                LetterSumAgg myagg2 = new LetterSumAgg();
                
                myagg2.add(partialSum);
                myagg1.add(myagg2.sum);
            }
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            LetterSumAgg myagg = (LetterSumAgg) agg;
            total = myagg.sum;
            return myagg.sum;
        }

    }
}

代码说明

这里有一些关于combiner的资源,Philippe Adjiman 讲得不错。

AggregationBuffer 容许咱们保存中间结果,经过定义咱们的buffer,咱们能够处理任何格式的数据,在代码例子中字符总数保存在AggregationBuffer 。

/**
* 保存当前字符总数的类
*/
static class LetterSumAgg implements AggregationBuffer {
	int sum = 0;
	void add(int num){
		sum += num;
	}
}

这意味着UDAF在不一样的mapreduce阶段会接收到不一样的输入。Iterate读取咱们表中的一行(或者准确来讲是表),而后输出其余数据格式的汇集结果。

artialAggregation合并这些汇集结果到另外相同格式的新的汇集结果,而后最终的reducer取得这些汇集结果真后输出最终结果(该结果或许与接收数据的格式不一致)。

在init()方法中咱们指定输入为string,结果输出格式为integer,还有,部分汇集结果输出格式为integer(保存在aggregation buffer中);terminate()terminatePartial()二者输出一个integer

// init方法中根据不一样的mode指定输出数据的格式objectinspector
if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
	inputOI = (PrimitiveObjectInspector) parameters[0];
} else {
	integerOI = (PrimitiveObjectInspector) parameters[0];
}

// 不一样model阶段的输出数据格式
outputOI = ObjectInspectorFactory.getReflectionObjectInspector(Integer.class,
                    ObjectInspectorOptions.JAVA);

iterate()函数读取到每行中列的字符串,计算与保存该字符串的长度

public void iterate(AggregationBuffer agg, Object[] parameters)
	throws HiveException {
	...
	Object p1 = ((PrimitiveObjectInspector) inputOI).getPrimitiveJavaObject(parameters[0]);
	myagg.add(String.valueOf(p1).length());
	}
}

Merge函数增长部分汇集总数到AggregationBuffer

public void merge(AggregationBuffer agg, Object partial)
      	throws HiveException {
	if (partial != null) {
                
		LetterSumAgg myagg1 = (LetterSumAgg) agg;
                
		Integer partialSum = (Integer) integerOI.getPrimitiveJavaObject(partial);
                
		LetterSumAgg myagg2 = new LetterSumAgg();
                
		myagg2.add(partialSum);
		myagg1.add(myagg2.sum);
	}
}

Terminate()函数返回AggregationBuffer中的内容,这里产生了最终结果。

public Object terminate(AggregationBuffer agg) throws HiveException {
	LetterSumAgg myagg = (LetterSumAgg) agg;
	total = myagg.sum;
	return myagg.sum;
}

使用自定义函数

ADD JAR ./hive-extension-examples-master/target/hive-extensions-1.0-SNAPSHOT-jar-with-dependencies.jar;
CREATE TEMPORARY FUNCTION letters as 'com.matthewrathbone.example.TotalNumOfLettersGenericUDAF';

SELECT letters(name) FROM people;
OK
44
Time taken: 20.688 seconds

资料参考