Aggregates are unevaluable Catalyst expressions: they cannot be evaluated row by row through the usual eval and doGenCode methods.
Our basic requirements are to use a user-defined Java object as the internal Spark aggregation buffer type, and to pass extra arguments to the aggregate, e.g. aggregate(col, 0.24).
Spark provides the TypedImperativeAggregate[T] contract for exactly this ("imperative" because the aggregation is expressed in terms of imperative initialize, update, and merge methods).
TypedImperativeAggregate[T] abstract class
```scala
case class TestAggregation(child: Expression) extends TypedImperativeAggregate[T] {
  // Check input types
  override def checkInputDataTypes(): TypeCheckResult
  // Initialize T
  override def createAggregationBuffer(): T
  // Update T with row
  override def update(buffer: T, inputRow: InternalRow): T
  // Merge intermediate buffers onto the first buffer
  override def merge(buffer: T, other: T): T
  // Final value
  override def eval(buffer: T): Any
  override def withNewMutableAggBufferOffset(newOffset: Int): TestAggregation
  override def withNewInputAggBufferOffset(newOffset: Int): TestAggregation
  override def children: Seq[Expression]
  override def nullable: Boolean
  // Data type of the output
  override def dataType: DataType
  override def prettyName: String
  override def serialize(obj: T): Array[Byte]
  override def deserialize(bytes: Array[Byte]): T
}
```
Example
The case class Average holds the count and sum of the elements seen so far and also acts as the internal aggregation buffer.
The aggregate takes a numeric column and an extra argument n, and returns avg(column) * n. For example, with salaries 3000, 4500, 3500 and 4000 and n = 2, it returns (15000 * 2) / 4 = 7500.
Spark Alchemy's NativeFunctionRegistration is used to register the function with Spark.
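To use NativeFunctionRegistration, the spark-alchemy library has to be on the classpath. A minimal sbt sketch, assuming the com.swoop group id; check the spark-alchemy README for the exact artifact name and a version that matches your Spark and Scala build:

```scala
// build.sbt (coordinates assumed; verify against the spark-alchemy README)
libraryDependencies += "com.swoop" %% "spark-alchemy" % "<version>"
```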
Aggregate code:
```scala
import com.swoop.alchemy.spark.expressions.NativeFunctionRegistration
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.types._

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

case class Average(var sum: Long, var count: Long)

case class AvgTest(
    child: Expression,
    nExpression: Expression,
    override val mutableAggBufferOffset: Int = 0,
    override val inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[Average] {

  // private lazy val n: Long = nExpression.eval().asInstanceOf[Long]

  def this(child: Expression) = this(child, Literal(1), 0, 0)

  def this(child: Expression, nExpression: Expression) = this(child, nExpression, 0, 0)

  override def checkInputDataTypes(): TypeCheckResult = {
    child.dataType match {
      case LongType => TypeCheckResult.TypeCheckSuccess
      case _        => TypeCheckResult.TypeCheckFailure(s"$prettyName only supports long input")
    }
  }

  override def createAggregationBuffer(): Average = {
    new Average(0, 0)
  }

  override def update(buffer: Average, inputRow: InternalRow): Average = {
    val value = child.eval(inputRow)
    buffer.sum += value.asInstanceOf[Long]
    buffer.count += 1
    buffer
  }

  override def merge(buffer: Average, other: Average): Average = {
    buffer.sum += other.sum
    buffer.count += other.count
    buffer
  }

  override def eval(buffer: Average): Any = {
    val n: Int = nExpression.eval().asInstanceOf[Int]
    (buffer.sum * n) / buffer.count
  }

  override def withNewMutableAggBufferOffset(newOffset: Int): AvgTest =
    copy(mutableAggBufferOffset = newOffset)

  override def withNewInputAggBufferOffset(newOffset: Int): AvgTest =
    copy(inputAggBufferOffset = newOffset)

  override def children: Seq[Expression] = Seq(child, nExpression)

  override def nullable: Boolean = true

  // The result type is the same as the input type.
  override def dataType: DataType = child.dataType

  override def prettyName: String = "avg_test"

  override def serialize(obj: Average): Array[Byte] = {
    val stream: ByteArrayOutputStream = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(stream)
    oos.writeObject(obj)
    oos.close()
    stream.toByteArray
  }

  override def deserialize(bytes: Array[Byte]): Average = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    val value = ois.readObject
    ois.close()
    value.asInstanceOf[Average]
  }
}
```
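Because every method in the contract is plain Scala, the aggregate can be sanity-checked by hand before wiring it into SQL. A minimal sketch, assuming the AvgTest class above; the BoundReference ordinal and the InternalRow values are made up for illustration and mimic two partitions of the employees data:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal}
import org.apache.spark.sql.types.LongType

// child reads the first (long) column of each row; n = 2
val agg = AvgTest(BoundReference(0, LongType, nullable = false), Literal(2))

// One partial buffer per "partition", updated row by row
val partial1 = Seq(3000L, 4500L).map(v => InternalRow(v))
  .foldLeft(agg.createAggregationBuffer())(agg.update)
val partial2 = Seq(3500L, 4000L).map(v => InternalRow(v))
  .foldLeft(agg.createAggregationBuffer())(agg.update)

// Partial buffers cross the shuffle as bytes, then get merged and evaluated
val merged = agg.merge(agg.deserialize(agg.serialize(partial1)), partial2)
println(agg.eval(merged)) // 7500
```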
Driver code:
```scala
object TestAgg {

  object BegRegister extends NativeFunctionRegistration {
    val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
      expression[AvgTest]("multiply_average")
    )
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName("FirstDemo")
    val sc = new SparkContext(conf)
    val spark = SparkSession.builder().appName("Demo").config(conf).getOrCreate()

    BegRegister.registerFunctions(spark)

    val df = spark.read.json("src/test/resources/employees.json")
    df.createOrReplaceTempView("employees")
    df.show()
    /*
    +-------+------+
    |   name|salary|
    +-------+------+
    |Michael|  3000|
    |   Andy|  4500|
    | Justin|  3500|
    |  Berta|  4000|
    +-------+------+
    */

    val result = spark.sql("SELECT multiply_average(salary) as average_salary FROM employees")
    result.show()
    /*
    +--------------+
    |average_salary|
    +--------------+
    |          3750|
    +--------------+
    */

    val result1 = spark.sql("SELECT multiply_average(salary, 2) as average_salary FROM employees")
    result1.show()
    /*
    +--------------+
    |average_salary|
    +--------------+
    |          7500|
    +--------------+
    */

    val result2 = spark.sql("SELECT multiply_average(salary, 3) as average_salary FROM employees")
    result2.show()
    /*
    +--------------+
    |average_salary|
    +--------------+
    |         11250|
    +--------------+
    */
  }
}
```
Here, nExpression is the Catalyst expression behind our extra argument n; eval evaluates it and uses it to scale the average. When the function is called with a single argument, the auxiliary constructor defaults it to Literal(1), which is why multiply_average(salary) returns the plain average. The remaining overrides follow the contract described above.
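Since the function is registered in the session's function registry, it can also be called from the DataFrame API through SQL expressions. A short usage sketch, reusing the df from the driver code above:

```scala
import org.apache.spark.sql.functions.expr

df.select(expr("multiply_average(salary, 2)").as("average_salary")).show()
// equivalently
df.selectExpr("multiply_average(salary, 3) as average_salary").show()
```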