This article walks through how Flink generates a StreamGraph.
Source code analysis
We use WordCount as the example.
final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
... (intermediate steps omitted)
// execute the program
env.execute("Streaming WordCount");
Next, step into the execute method of StreamExecutionEnvironment:
public abstract JobExecutionResult execute(String jobName) throws Exception;
This method is abstract. Debugging shows that the call ends up in the execute method of the subclass StreamContextEnvironment, shown below:
public JobExecutionResult execute(String jobName) throws Exception {
Preconditions.checkNotNull("Streaming Job name should not be null.");
StreamGraph streamGraph = this.getStreamGraph(); // obtain the StreamGraph
streamGraph.setJobName(jobName);
transformations.clear();
// execute the programs
if (ctx instanceof DetachedEnvironment) {
LOG.warn("Job was executed in detached mode, the results will be available on completion.");
((DetachedEnvironment) ctx).setDetachedPlan(streamGraph);
return DetachedEnvironment.DetachedJobExecutionResult.INSTANCE;
} else {
return ctx
.getClient()
.run(streamGraph, ctx.getJars(), ctx.getClasspaths(), ctx.getUserCodeClassLoader(), ctx.getSavepointRestoreSettings())
.getJobExecutionResult();
}
}
The key line is StreamGraph streamGraph = this.getStreamGraph(); so we follow that call:
@Internal
public StreamGraph getStreamGraph() {
if (transformations.size() <= 0) {
throw new IllegalStateException("No operators defined in streaming topology. Cannot execute.");
}
return StreamGraphGenerator.generate(this, transformations); // generate the StreamGraph
}
Here StreamGraphGenerator.generate is called with transformations, which is a List<StreamTransformation<?>>. So what is StreamTransformation? It is an abstract class with many implementation classes, shown in the figure below:
How does a concrete transformation end up in the transformations list? We follow DataStream.map as an example.
public <R> SingleOutputStreamOperator<R> map(MapFunction<T, R> mapper) {
TypeInformation<R> outType = TypeExtractor.getMapReturnTypes(clean(mapper), getType(),
Utils.getCallLocationName(), true);
return transform("Map", outType, new StreamMap<>(clean(mapper))); // delegates to transform
}
public <R> SingleOutputStreamOperator<R> transform(String operatorName, TypeInformation<R> outTypeInfo, OneInputStreamOperator<T, R> operator) {
// read the output type of the input Transform to coax out errors about MissingTypeInfo
transformation.getOutputType();
OneInputTransformation<T, R> resultTransform = new OneInputTransformation<>(
this.transformation,
operatorName,
operator,
outTypeInfo,
environment.getParallelism());
@SuppressWarnings({ "unchecked", "rawtypes" })
SingleOutputStreamOperator<R> returnStream = new SingleOutputStreamOperator(environment, resultTransform);
getExecutionEnvironment().addOperator(resultTransform); // registers it in the transformations list
return returnStream;
}
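The two methods above show that map delegates to transform. The MapFunction is wrapped in a StreamMap, which is a OneInputStreamOperator; that operator is wrapped in a OneInputTransformation; and the transformation is in turn wrapped in a SingleOutputStreamOperator. SingleOutputStreamOperator is a subclass of DataStream, so map returns a DataStream again, on which filter and other methods can be called. This is what makes chained calls possible in Flink.
The DataStream-to-DataStream conversion therefore creates a OneInputTransformation and then executes the call below, whose addOperator implementation follows: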
getExecutionEnvironment().addOperator(resultTransform)
public void addOperator(StreamTransformation<?> transformation) {
Preconditions.checkNotNull(transformation, "transformation must not be null.");
this.transformations.add(transformation);
}
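To make the registration step concrete, here is a minimal sketch of the same idea without any Flink dependency. MiniEnvironment, MiniStream, MiniTransformation, fromSource and apply are hypothetical names invented for this illustration; they only mimic the shape of the code above and are not Flink's API. In this sketch the source is not registered and stays reachable only through the input pointer.

```java
import java.util.ArrayList;
import java.util.List;

public class ChainedCallSketch {

    /** Hypothetical stand-in for StreamTransformation: a name plus a pointer to its input. */
    static final class MiniTransformation {
        final String name;
        final MiniTransformation input; // null for a source

        MiniTransformation(String name, MiniTransformation input) {
            this.name = name;
            this.input = input;
        }
    }

    /** Hypothetical stand-in for the environment that collects transformations. */
    static final class MiniEnvironment {
        final List<MiniTransformation> transformations = new ArrayList<>();

        void addOperator(MiniTransformation transformation) {
            transformations.add(transformation);
        }

        MiniStream fromSource(String name) {
            // Assumption of this sketch: the source is not added to the list.
            return new MiniStream(this, new MiniTransformation(name, null));
        }
    }

    /** Hypothetical stand-in for DataStream: every operation returns a new stream. */
    static final class MiniStream {
        final MiniEnvironment env;
        final MiniTransformation transformation;

        MiniStream(MiniEnvironment env, MiniTransformation transformation) {
            this.env = env;
            this.transformation = transformation;
        }

        MiniStream apply(String operatorName) {
            // The new transformation points back at the current one ...
            MiniTransformation result = new MiniTransformation(operatorName, this.transformation);
            // ... and is registered with the environment before the new stream is returned.
            env.addOperator(result);
            return new MiniStream(env, result);
        }
    }

    static List<String> registeredNames(MiniEnvironment env) {
        List<String> names = new ArrayList<>();
        for (MiniTransformation t : env.transformations) {
            names.add(t.name);
        }
        return names;
    }

    public static void main(String[] args) {
        MiniEnvironment env = new MiniEnvironment();
        MiniStream last = env.fromSource("Source").apply("Map").apply("Filter");
        System.out.println(registeredNames(env));           // prints [Map, Filter]
        System.out.println(last.transformation.input.name); // prints Map
    }
}
```

Because each call returns a new stream that wraps the freshly registered transformation, calls can be chained, and the environment ends up with the transformations in call order.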
So it is during these DataStream conversions that each transformation is added to the transformations list. Next we follow return StreamGraphGenerator.generate(this, transformations):
public static StreamGraph generate(StreamExecutionEnvironment env, List<StreamTransformation<?>> transformations) {
return new StreamGraphGenerator(env).generateInternal(transformations);
}
private StreamGraph generateInternal(List<StreamTransformation<?>> transformations) {
for (StreamTransformation<?> transformation: transformations) {
transform(transformation); // core logic
}
return streamGraph;
}
The core logic lives in transform(transformation). Here it is in detail:
private Collection<Integer> transform(StreamTransformation<?> transform) {
if (alreadyTransformed.containsKey(transform)) {
return alreadyTransformed.get(transform);
}
LOG.debug("Transforming " + transform);
if (transform.getMaxParallelism() <= 0) {
// if the max parallelism hasn't been set, then first use the job wide max parallelism
// from theExecutionConfig.
int globalMaxParallelismFromConfig = env.getConfig().getMaxParallelism();
if (globalMaxParallelismFromConfig > 0) {
transform.setMaxParallelism(globalMaxParallelismFromConfig);
}
}
// call at least once to trigger exceptions about MissingTypeInfo
transform.getOutputType();
Collection<Integer> transformedIds;
if (transform instanceof OneInputTransformation<?, ?>) {
transformedIds = transformOneInputTransform((OneInputTransformation<?, ?>) transform);
} else if (transform instanceof TwoInputTransformation<?, ?, ?>) {
transformedIds = transformTwoInputTransform((TwoInputTransformation<?, ?, ?>) transform);
} else if (transform instanceof SourceTransformation<?>) {
transformedIds = transformSource((SourceTransformation<?>) transform);
} else if (transform instanceof SinkTransformation<?>) {
transformedIds = transformSink((SinkTransformation<?>) transform);
} else if (transform instanceof UnionTransformation<?>) {
transformedIds = transformUnion((UnionTransformation<?>) transform);
} else if (transform instanceof SplitTransformation<?>) {
transformedIds = transformSplit((SplitTransformation<?>) transform);
} else if (transform instanceof SelectTransformation<?>) {
transformedIds = transformSelect((SelectTransformation<?>) transform);
} else if (transform instanceof FeedbackTransformation<?>) {
transformedIds = transformFeedback((FeedbackTransformation<?>) transform);
} else if (transform instanceof CoFeedbackTransformation<?>) {
transformedIds = transformCoFeedback((CoFeedbackTransformation<?>) transform);
} else if (transform instanceof PartitionTransformation<?>) {
transformedIds = transformPartition((PartitionTransformation<?>) transform);
} else if (transform instanceof SideOutputTransformation<?>) {
transformedIds = transformSideOutput((SideOutputTransformation<?>) transform);
} else {
throw new IllegalStateException("Unknown transformation: " + transform);
}
// need this check because the iterate transformation adds itself before
// transforming the feedback edges
if (!alreadyTransformed.containsKey(transform)) {
alreadyTransformed.put(transform, transformedIds);
}
if (transform.getBufferTimeout() >= 0) {
streamGraph.setBufferTimeout(transform.getId(), transform.getBufferTimeout());
}
if (transform.getUid() != null) {
streamGraph.setTransformationUID(transform.getId(), transform.getUid());
}
if (transform.getUserProvidedNodeHash() != null) {
streamGraph.setTransformationUserHash(transform.getId(), transform.getUserProvidedNodeHash());
}
if (transform.getMinResources() != null && transform.getPreferredResources() != null) {
streamGraph.setResources(transform.getId(), transform.getMinResources(), transform.getPreferredResources());
}
return transformedIds;
}
This method dispatches to a different transformXXX method depending on the type of the transform. We pick the OneInputTransformation case and look at how transformOneInputTransform is implemented:
private <IN, OUT> Collection<Integer> transformOneInputTransform(OneInputTransformation<IN, OUT> transform) {
Collection<Integer> inputIds = transform(transform.getInput()); // recursively transform the upstream input first
// the recursive call might have already transformed this
if (alreadyTransformed.containsKey(transform)) {
return alreadyTransformed.get(transform);
}
String slotSharingGroup = determineSlotSharingGroup(transform.getSlotSharingGroup(), inputIds);
streamGraph.addOperator(transform.getId(),
slotSharingGroup,
transform.getCoLocationGroupKey(),
transform.getOperator(),
transform.getInputType(),
transform.getOutputType(),
transform.getName()); // adds a StreamNode
if (transform.getStateKeySelector() != null) {
TypeSerializer<?> keySerializer = transform.getStateKeyType().createSerializer(env.getConfig());
streamGraph.setOneInputStateKey(transform.getId(), transform.getStateKeySelector(), keySerializer);
}
streamGraph.setParallelism(transform.getId(), transform.getParallelism());
streamGraph.setMaxParallelism(transform.getId(), transform.getMaxParallelism());
for (Integer inputId: inputIds) {
streamGraph.addEdge(inputId, transform.getId(), 0); // adds a StreamEdge from the input to this node
}
return Collections.singleton(transform.getId());
}
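The recursion in transformOneInputTransform, together with the alreadyTransformed cache, can be modelled roughly as follows. This is a simplified sketch, not Flink's implementation: MiniTransformation, MiniGraph and MiniGenerator are hypothetical names, every transformation has at most one input, and edges are stored as plain strings.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class TransformRecursionSketch {

    /** Hypothetical transformation with an id, a name and a pointer to its input. */
    static final class MiniTransformation {
        final int id;
        final String name;
        final MiniTransformation input; // null for a source

        MiniTransformation(int id, String name, MiniTransformation input) {
            this.id = id;
            this.name = name;
            this.input = input;
        }
    }

    /** Hypothetical graph: node id to name, plus edges written as "from->to". */
    static final class MiniGraph {
        final Map<Integer, String> nodes = new LinkedHashMap<>();
        final List<String> edges = new ArrayList<>();
    }

    static final class MiniGenerator {
        final MiniGraph graph = new MiniGraph();
        final Map<MiniTransformation, Collection<Integer>> alreadyTransformed = new HashMap<>();

        Collection<Integer> transform(MiniTransformation t) {
            // Memoization: a transformation reached twice is only converted once.
            if (alreadyTransformed.containsKey(t)) {
                return alreadyTransformed.get(t);
            }
            // Convert the upstream transformation first, so its node exists before the edge is added.
            Collection<Integer> inputIds =
                    (t.input == null) ? Collections.<Integer>emptyList() : transform(t.input);
            graph.nodes.put(t.id, t.name);
            for (Integer inputId : inputIds) {
                graph.edges.add(inputId + "->" + t.id);
            }
            Collection<Integer> ids = Collections.singleton(t.id);
            alreadyTransformed.put(t, ids);
            return ids;
        }

        MiniGraph generate(List<MiniTransformation> transformations) {
            for (MiniTransformation t : transformations) {
                transform(t);
            }
            return graph;
        }
    }

    public static void main(String[] args) {
        MiniTransformation source = new MiniTransformation(1, "Source", null);
        MiniTransformation map = new MiniTransformation(2, "Map", source);
        MiniTransformation sink = new MiniTransformation(3, "Sink", map);

        // Only Map and Sink are in the list; Source is reached through the recursion.
        MiniGraph graph = new MiniGenerator().generate(Arrays.asList(map, sink));
        System.out.println(graph.nodes); // prints {1=Source, 2=Map, 3=Sink}
        System.out.println(graph.edges); // prints [1->2, 2->3]
    }
}
```

When Sink is processed, the recursive call for Map hits the cache and returns its id immediately, which is why Map and its edge from Source appear only once.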
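transformOneInputTransform first recursively transforms its upstream transform, then adds a StreamNode and the StreamEdges that connect it to its inputs. Next, look at streamGraph.addOperator: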
public <IN, OUT> void addOperator(
Integer vertexID,
String slotSharingGroup,
@Nullable String coLocationGroup,
StreamOperator<OUT> operatorObject,
TypeInformation<IN> inTypeInfo,
TypeInformation<OUT> outTypeInfo,
String operatorName) {
if (operatorObject instanceof StoppableStreamSource) { // add the node; the task class depends on the operator type
addNode(vertexID, slotSharingGroup, coLocationGroup, StoppableSourceStreamTask.class, operatorObject, operatorName);
} else if (operatorObject instanceof StreamSource) {
addNode(vertexID, slotSharingGroup, coLocationGroup, SourceStreamTask.class, operatorObject, operatorName);
} else {
addNode(vertexID, slotSharingGroup, coLocationGroup, OneInputStreamTask.class, operatorObject, operatorName);
}
TypeSerializer<IN> inSerializer = inTypeInfo != null && !(inTypeInfo instanceof MissingTypeInfo) ? inTypeInfo.createSerializer(executionConfig) : null;
TypeSerializer<OUT> outSerializer = outTypeInfo != null && !(outTypeInfo instanceof MissingTypeInfo) ? outTypeInfo.createSerializer(executionConfig) : null;
setSerializers(vertexID, inSerializer, null, outSerializer);
if (operatorObject instanceof OutputTypeConfigurable && outTypeInfo != null) {
@SuppressWarnings("unchecked")
OutputTypeConfigurable<OUT> outputTypeConfigurable = (OutputTypeConfigurable<OUT>) operatorObject;
// sets the output type which must be know at StreamGraph creation time
outputTypeConfigurable.setOutputType(outTypeInfo, executionConfig);
}
if (operatorObject instanceof InputTypeConfigurable) {
InputTypeConfigurable inputTypeConfigurable = (InputTypeConfigurable) operatorObject;
inputTypeConfigurable.setInputType(inTypeInfo, executionConfig);
}
if (LOG.isDebugEnabled()) {
LOG.debug("Vertex: {}", vertexID);
}
}
// addNode creates the StreamNode and stores it in streamNodes
protected StreamNode addNode(Integer vertexID,
String slotSharingGroup,
@Nullable String coLocationGroup,
Class<? extends AbstractInvokable> vertexClass,
StreamOperator<?> operatorObject,
String operatorName) {
if (streamNodes.containsKey(vertexID)) {
throw new RuntimeException("Duplicate vertexID " + vertexID);
}
StreamNode vertex = new StreamNode(environment,
vertexID,
slotSharingGroup,
coLocationGroup,
operatorObject,
operatorName,
new ArrayList<OutputSelector<?>>(),
vertexClass);
streamNodes.put(vertexID, vertex);
return vertex;
}
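Summary of how the conversion works: each transform in the List<StreamTransformation<?>> (apart from sources, which have no input) holds an input field that acts like a pointer to its parent transform, so all transforms are linked together. As in the OneInputTransformation case above, a transform is then converted into a StreamNode and the connecting pointers are converted into StreamEdges. Each StreamEdge holds a StreamPartitioner, which determines how data is partitioned and distributed between StreamNodes. The kinds include broadcast, forward, shuffle, rebalance, and others.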
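To give a feel for what the partitioning kinds mean, the sketch below models three of them as a function from the number of downstream channels to the channels a record is sent to. MiniPartitioner, Forward, Rebalance and Broadcast are hypothetical classes written for this illustration, and the selection rules are assumptions chosen to match the usual meaning of the names; Flink's StreamPartitioner implementations differ in interface and detail. Shuffle (random selection) is left out to keep the output deterministic.

```java
import java.util.Arrays;

public class PartitionerSketch {

    /** Hypothetical interface: which downstream channels receive the next record. */
    interface MiniPartitioner {
        int[] selectChannels(int numberOfChannels);
    }

    /** Forward: always the same single channel (channel 0 in this sketch). */
    static final class Forward implements MiniPartitioner {
        @Override
        public int[] selectChannels(int numberOfChannels) {
            return new int[] {0};
        }
    }

    /** Rebalance: round-robin over all channels, starting at channel 0 in this sketch. */
    static final class Rebalance implements MiniPartitioner {
        private int next = -1;

        @Override
        public int[] selectChannels(int numberOfChannels) {
            next = (next + 1) % numberOfChannels;
            return new int[] {next};
        }
    }

    /** Broadcast: every record goes to every channel. */
    static final class Broadcast implements MiniPartitioner {
        @Override
        public int[] selectChannels(int numberOfChannels) {
            int[] all = new int[numberOfChannels];
            for (int i = 0; i < numberOfChannels; i++) {
                all[i] = i;
            }
            return all;
        }
    }

    public static void main(String[] args) {
        MiniPartitioner rebalance = new Rebalance();
        for (int record = 0; record < 4; record++) {
            // prints [0], [1], [2], [0] on successive lines
            System.out.println(Arrays.toString(rebalance.selectChannels(3)));
        }
        System.out.println(Arrays.toString(new Forward().selectChannels(3)));   // prints [0]
        System.out.println(Arrays.toString(new Broadcast().selectChannels(3))); // prints [0, 1, 2]
    }
}
```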