ctc_greedy_decoder

tf_export exports the symbol to the public API through the api_export class.
add_dispatch_support is a decorator that adds a dispatch-handling wrapper to a TensorFlow Python API.

@tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs,
                       sequence_length,
                       merge_repeated=True,
                       blank_index=None):
  """Performs greedy decoding on the logits given in input (best path).

  Given a tensor as `inputs`, the `blank_index` parameter defines the class
  index of the blank symbol.

  For example:

  If `blank_index` is equal to 1:

  >>> inf = float("inf")
  >>> logits = tf.constant([[[   0., -inf, -inf],
  ...                        [ -2.3, -inf, -0.1]],
  ...                       [[ -inf, -0.5, -inf],
  ...                        [ -inf, -inf, -0.1]],
  ...                       [[ -inf, -inf, -inf],
  ...                        [ -0.1, -inf, -2.3]]])
  >>> seq_lens = tf.constant([2, 3])
  >>> outputs = tf.nn.ctc_greedy_decoder(
  ...     logits,
  ...     seq_lens,
  ...     blank_index=1)

  Notes:

  - Unlike `ctc_beam_search_decoder`, `ctc_greedy_decoder` considers blanks
    as regular elements when computing the probability of a sequence.
  - Default `blank_index` is `(num_classes - 1)`, unless overridden.

  If `merge_repeated` is `True`, merge repeated classes in output.
  This means that if consecutive logits' maximum indices are the same,
  only the first of these is emitted.  The sequence `A B B * B * B` (where '*'
  is the blank label) becomes

    * `A B B B` if `merge_repeated=True`.
    * `A B B B B` if `merge_repeated=False`.

  Args:
    inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    merge_repeated: Boolean.  Default: True.
    blank_index: (Optional). Default: `num_classes - 1`. Define the class index
      to use for the blank label. Negative values will start from num_classes,
      ie, -1 will reproduce the ctc_greedy_decoder behavior of using
      num_classes - 1 for the blank symbol, which corresponds to the default.

  Returns:
    A tuple `(decoded, neg_sum_logits)` where

    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:

      `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
        The rows store: `[batch, time]`.

      `decoded.values`: Values vector, size `(total_decoded_outputs)`.
        The vector stores the decoded classes.

      `decoded.dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length]`

    neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
        sequence found, the negative of the sum of the greatest logit at each
        timeframe.
  """

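The gen_ctc_ops.py file is generated by tf_gen_op_wrapper_private_py from the information in tensorflow/python/BUILD. CamelCase C++ op names are converted to lowercase, underscore-separated (snake_case) Python names.
CTCGreedyDecoderOp performs greedy (best-path) decoding on the logits given in the input.
The wrapper returns a list containing one SparseTensor, plus a matrix that stores, for each batch entry, the negative sum of the largest logit at each time frame.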

  outputs = gen_ctc_ops.ctc_greedy_decoder(
      inputs,
      sequence_length,
      merge_repeated=merge_repeated,
      blank_index=blank_index)
  (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
  return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val,
                                      decoded_shape)], log_probabilities)
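The CamelCase-to-snake_case renaming mentioned above can be illustrated with a small sketch. `camel_to_snake` is a hypothetical helper that only approximates the rule used by TensorFlow's wrapper generator. It is not the generator's actual code. It inserts an underscore before a capital letter that follows a lowercase letter, or before the last capital of an acronym that starts a new word.

```python
def camel_to_snake(name: str) -> str:
    """Approximate conversion of a CamelCase op name to snake_case."""
    out = []
    for i, ch in enumerate(name):
        if ch.isupper() and i > 0:
            prev = name[i - 1]
            nxt = name[i + 1] if i + 1 < len(name) else ""
            # New word after a lowercase letter ("M" in "MatMul"), or at the
            # last capital of an acronym followed by lowercase ("G" in "CTCGreedy").
            if prev.islower() or (prev.isupper() and nxt.islower()):
                out.append("_")
        out.append(ch.lower())
    return "".join(out)


print(camel_to_snake("CTCGreedyDecoder"))  # ctc_greedy_decoder
```

Digits do not trigger a split in this sketch, so "Conv2D" becomes "conv2d".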

REGISTER_OP("CTCGreedyDecoder")

REGISTER_OP registers the op.

REGISTER_OP("CTCGreedyDecoder")
    .Input("inputs: T")
    .Input("sequence_length: int32")
    .Attr("merge_repeated: bool = false")
    .Attr("blank_index: int = -1")
    .Output("decoded_indices: int64")
    .Output("decoded_values: int64")
    .Output("decoded_shape: int64")
    .Output("log_probability: T")
    .Attr("T: {float, double} = DT_FLOAT")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle inputs;
      ShapeHandle sequence_length;

      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));

      // Get batch size from inputs and sequence_length.
      DimensionHandle batch_size;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));

      DimensionHandle total_decoded_outputs = c->UnknownDim();
      c->set_output(0, c->Matrix(total_decoded_outputs, 2));
      c->set_output(1, c->Vector(total_decoded_outputs));
      c->set_output(2, c->Vector(2));
      c->set_output(3, c->Matrix(batch_size, 1));
      return Status::OK();
    });
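To make the shape function concrete, here is a minimal Python model of the same inference rules. It assumes shapes are tuples with `None` for an unknown dimension. `merge_dim` and `infer_ctc_greedy_shapes` are hypothetical helpers that mimic WithRank and Merge. They are not TensorFlow APIs.

```python
from typing import Optional

Dim = Optional[int]  # None means "unknown dimension"


def merge_dim(a: Dim, b: Dim) -> Dim:
    """Mimic InferenceContext::Merge: unknown merges with anything,
    two known dimensions must be equal."""
    if a is None:
        return b
    if b is None or a == b:
        return a
    raise ValueError(f"Dimensions must be equal, but are {a} and {b}")


def infer_ctc_greedy_shapes(inputs_shape, seq_len_shape):
    # WithRank checks: inputs is [max_time, batch_size, num_classes],
    # sequence_length is [batch_size].
    if len(inputs_shape) != 3:
        raise ValueError("inputs must be rank 3")
    if len(seq_len_shape) != 1:
        raise ValueError("sequence_length must be rank 1")
    batch_size = merge_dim(inputs_shape[1], seq_len_shape[0])
    total = None  # the number of decoded labels is only known at run time
    return (
        (total, 2),       # decoded_indices
        (total,),         # decoded_values
        (2,),             # decoded_shape
        (batch_size, 1),  # log_probability
    )


print(infer_ctc_greedy_shapes((50, None, 30), (8,)))
# ((None, 2), (None,), (2,), (8, 1))
```

The unknown batch dimension of `inputs` is filled in from `sequence_length`. Two known but different batch sizes are rejected, which is what Merge does.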

REGISTER_OP

The OpDef is constructed through OpDefBuilderWrapper.

#define REGISTER_OP_IMPL(ctr, name, is_system_op)                         \
  static ::tensorflow::InitOnStartupMarker const register_op##ctr         \
      TF_ATTRIBUTE_UNUSED =                                               \
          TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \
          << ::tensorflow::register_op::OpDefBuilderWrapper(name)

#define REGISTER_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)

REGISTER_CPU

Name is essentially a KernelDefBuilder object that builds a KernelDef internally.
KernelDefBuilder::Device sets the device type.
KernelDefBuilder::TypeConstraint sets a type constraint.
REGISTER_KERNEL_BUILDER registers the kernel through the OpKernelRegistrar class.

#define REGISTER_CPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCGreedyDecoderOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);
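Kernel registration can be pictured as a table keyed by op name, device type and the value of the type attribute T. REGISTER_CPU(float) and REGISTER_CPU(double) each add one row. The Python sketch below is a hypothetical model of that lookup: `KERNEL_REGISTRY`, `register_kernel`, `find_kernel` and `CTCGreedyDecoderKernel` are illustrative names, not TensorFlow APIs. The real OpKernelRegistrar and kernel lookup are implemented in C++ and handle more constraints than this.

```python
KERNEL_REGISTRY = {}  # (op_name, device, dtype) -> kernel factory


def register_kernel(op_name, device, dtype, factory):
    key = (op_name, device, dtype)
    if key in KERNEL_REGISTRY:
        raise ValueError(f"Duplicate kernel registration for {key}")
    KERNEL_REGISTRY[key] = factory


def find_kernel(op_name, device, dtype):
    try:
        return KERNEL_REGISTRY[(op_name, device, dtype)]
    except KeyError:
        raise LookupError(
            f"No kernel registered for {op_name} on {device} with T={dtype}"
        ) from None


class CTCGreedyDecoderKernel:
    """Stand-in for the templated CTCGreedyDecoderOp<T>."""

    def __init__(self, dtype):
        self.dtype = dtype


# Analogue of REGISTER_CPU(float); REGISTER_CPU(double);
for _t in ("float", "double"):
    register_kernel("CTCGreedyDecoder", "CPU", _t,
                    lambda t=_t: CTCGreedyDecoderKernel(t))
```

The `t=_t` default argument captures each loop value, so each factory instantiates the kernel for its own type, just as each REGISTER_CPU expansion instantiates its own template.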

CTCGreedyDecoderOp

OpKernelConstruction::GetAttr reads the value of an attribute.

template <typename T>
class CTCGreedyDecoderOp : public OpKernel {
 public:
  explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("blank_index", &blank_index_));
  }

CTCGreedyDecoderOp::Compute

Figure: call graph of CTCGreedyDecoderOp::Compute (nodes: CTCGreedyDecoderOp::Compute, CTCDecodeHelper::ValidateInputsGenerateOutputs, CTCDecodeHelper::StoreAllDecodedSequences).

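The inputs are Tensor pointers.
OpOutputList is a list of output tensors that together make up a single named output.
CTCDecodeHelper::ValidateInputsGenerateOutputs validates the inputs and allocates the output tensors.
TensorShape represents the shape of a tensor.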

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));

    const TensorShape& inputs_shape = inputs->shape();

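TTypes<T>::UnalignedConstMatrix is a TensorMap<Tensor<data_type, rank>> type. A TensorMap creates a tensor on top of memory that is allocated and owned by another part of the code, which lets any allocated memory be viewed as a tensor. Instances of this class do not own the memory that stores the data, and for that reason a TensorMap cannot be resized.
TensorShapeBase::dim_size returns the size of the given dimension.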

    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
    const int64_t max_time = inputs_shape.dim_size(0);
    const int64_t batch_size = inputs_shape.dim_size(1);
    const int64_t num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<const int>(num_classes_raw);

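Tensor::tensor returns the nested TTypes::Tensor type.
The data of each time step is wrapped as a TTypes::UnalignedConstMatrix of shape [batch_size, num_classes] and appended to the input_list_t vector.
Tensor::vec returns a 1-D TTypes::Vec.
Tensor::matrix returns a 2-D TTypes::Matrix.
setZero (an Eigen tensor method) sets every element of log_prob_t to zero.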

    auto inputs_t = inputs->tensor<T, 3>();

    input_list_t.reserve(max_time);
    for (std::size_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<T>();

    log_prob_t.setZero();

    int blank_index =
        (blank_index_ < 0) ? num_classes + blank_index_ : blank_index_;
    OP_REQUIRES(ctx, FastBoundsCheck(blank_index, num_classes),
                errors::InvalidArgument("blank_index expected to be between ",
                                        -num_classes, " and ", num_classes - 1,
                                        " but was ", blank_index_));
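The per-time-step pointer arithmetic and the blank_index normalization above can be checked with a small pure-Python model. It assumes the logits are a flat row-major buffer of length max_time * batch_size * num_classes, as behind inputs_t.data(). Python memoryview slices stand in for the non-owning TensorMap views. `time_step_views` and `normalize_blank_index` are hypothetical helper names.

```python
from array import array


def time_step_views(flat, max_time, batch_size, num_classes):
    """Return one non-copying view per time step over a row-major buffer.

    Step t starts at offset t * batch_size * num_classes, mirroring
    inputs_t.data() + t * batch_size * num_classes in the kernel.
    """
    buf = memoryview(flat)
    step = batch_size * num_classes
    return [buf[t * step:(t + 1) * step] for t in range(max_time)]


def normalize_blank_index(blank_index, num_classes):
    """Wrap negative values (-1 -> num_classes - 1), then bounds-check."""
    idx = num_classes + blank_index if blank_index < 0 else blank_index
    if not 0 <= idx < num_classes:
        raise ValueError(
            f"blank_index expected to be between {-num_classes} and "
            f"{num_classes - 1} but was {blank_index}")
    return idx


logits = array("f", range(3 * 2 * 4))  # max_time=3, batch_size=2, num_classes=4
views = time_step_views(logits, 3, 2, 4)
# Element (t=1, b=1, c=2) lives at views[1][b * num_classes + c].
print(views[1][1 * 4 + 2])  # 14.0
```

Because the views share memory with `logits`, writing `logits[14]` is visible through `views[1][6]`, which is the same non-owning behaviour described for TensorMap.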

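The decode lambda loops over the batch entries in [begin, end) and decodes each one independently.
sequences stores the decoded values of every path for every batch entry, hence the three levels of nesting (sequences[b][p][ix]). The greedy decoder produces only one path per entry.
The length of each sequence is read from seq_len_t.
input_list_t[t] has shape [batch_size, num_classes]; RowMax finds the largest logit in row b (the current batch entry) together with its class index.
log_prob_t accumulates the negated sum of these maxima.
A class is appended to the path only if it is not the blank index and passes the repeat filter: with merge_repeated, a class equal to the previous time step's class is skipped.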

    // Perform best path decoding
    std::vector<std::vector<std::vector<int> > > sequences(batch_size);
    auto decode = [&](const int64_t begin, const int64_t end) {
      for (int b = begin; b < end; ++b) {
        sequences[b].resize(1);
        auto &sequence = sequences[b][0];
        int prev_indices = -1;
        for (int t = 0; t < seq_len_t(b); ++t) {
          int max_class_indices;
          OP_REQUIRES(ctx, input_list_t[t].dimension(1) > 0,
                      errors::InvalidArgument("Invalid input dimensions."));
          log_prob_t(b, 0) +=
              -RowMax<T>(input_list_t[t], b, &max_class_indices);
          if (max_class_indices != blank_index &&
              !(merge_repeated_ && max_class_indices == prev_indices)) {
            sequence.push_back(max_class_indices);
          }
          prev_indices = max_class_indices;
        }
      }
    };
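The decode lambda can be mirrored in plain Python to check its behaviour on the docstring example. This is a sketch, not TensorFlow code. `ctc_greedy_decode` is a hypothetical name, logits are nested lists indexed [t][b][c], and batch entries are processed sequentially instead of through Shard.

```python
def ctc_greedy_decode(logits, seq_lens, blank_index, merge_repeated=True):
    """Best-path (greedy) CTC decoding over time-major logits[t][b][c].

    Returns (sequences, neg_sum_logits): sequences[b] is the decoded label
    list of batch entry b, and neg_sum_logits[b] is minus the sum of the
    maximum logit over its first seq_lens[b] time steps.
    """
    sequences, neg_sum_logits = [], []
    for b, length in enumerate(seq_lens):
        seq, neg_sum, prev = [], 0.0, -1
        for t in range(length):
            row = logits[t][b]
            # max() returns the first maximal item, so ties keep the lowest
            # class index, as RowMax does with its strict ">" comparison.
            cls = max(range(len(row)), key=row.__getitem__)
            neg_sum -= row[cls]
            if cls != blank_index and not (merge_repeated and cls == prev):
                seq.append(cls)
            prev = cls  # also updated on blanks, so "B * B" emits two Bs
        sequences.append(seq)
        neg_sum_logits.append(neg_sum)
    return sequences, neg_sum_logits


inf = float("inf")
logits = [[[0.0, -inf, -inf], [-2.3, -inf, -0.1]],
          [[-inf, -0.5, -inf], [-inf, -inf, -0.1]],
          [[-inf, -inf, -inf], [-0.1, -inf, -2.3]]]
decoded, neg_sums = ctc_greedy_decode(logits, [2, 3], blank_index=1)
print(decoded)  # [[0], [2, 0]]
```

Batch entry 1 picks class 2 twice in a row, so merge_repeated collapses it. With `merge_repeated=False` the same entry decodes to `[2, 2, 0]`.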

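DeviceBase::tensorflow_cpu_worker_threads returns the nested struct DeviceBase::CpuWorkerThreads, which stores the number of threads and a pointer to the thread pool.
Shard splits the batch_size units of work across the worker thread pool, using kCostPerUnit as the estimated cost of one unit, and calls decode on each sub-range.
CTCDecodeHelper::StoreAllDecodedSequences converts sequences into the three OpOutputLists.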

    const int64_t kCostPerUnit = 50 * max_time * num_classes;
    const int64_t total = batch_size;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *ctx->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, total,
          kCostPerUnit, decode);

    OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
                            sequences, &decoded_indices, &decoded_values,
                            &decoded_shape));
  }

CTCDecodeHelper converts and stores the results.
TF_DISALLOW_COPY_AND_ASSIGN disables the copy constructor and the copy assignment operator.

 private:
  CTCDecodeHelper decode_helper_;
  bool merge_repeated_;
  int blank_index_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};

CTCDecodeHelper

 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

CTCDecodeHelper::ValidateInputsGenerateOutputs

Figure: call graph of CTCDecodeHelper::ValidateInputsGenerateOutputs (nodes: CTCDecodeHelper::ValidateInputsGenerateOutputs, OpKernelContext::input, OpKernelContext::allocate_output, OpKernelContext::output_list).

OpKernelContext::input looks up the input tensor with the given name.

  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

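Next, the shape and dimension information is read and checked.
DECLARE_ERROR defines the helpers for creating and checking error statuses (for example errors::InvalidArgument).
TensorShapeUtils::IsVector checks, from the dimension information, whether a shape is a vector.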

    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }
    if (inputs_shape.num_elements() == 0) {
      return errors::InvalidArgument("inputs must not be empty");
    }

    const int64_t max_time = inputs_shape.dim_size(0);
    const int64_t batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (!(batch_size == (*seq_len)->dim_size(0))) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size.  ",
          "len(sequence_length):  ", (*seq_len)->dim_size(0),
          " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    for (int b = 0; b < batch_size; ++b) {
      if (!(seq_len_t(b) <= max_time)) {
        return errors::FailedPrecondition("sequence_length(", b,
                                          ") <= ", max_time);
      }
    }

OpKernelContext::allocate_output allocates the memory for log_prob.
OpKernelContext::output_list gets the OpOutputList with the given name.

    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return Status::OK();
  }
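The validation sequence can be summarised in a short Python sketch. `validate_inputs` is a hypothetical helper. It takes the inputs shape as a tuple and the sequence lengths as a list, and it raises ValueError where the kernel returns InvalidArgument or FailedPrecondition. It returns the shape the kernel then allocates for log_probability, with top_paths_ = 1.

```python
def validate_inputs(inputs_shape, seq_len):
    """Mirror the checks in ValidateInputsGenerateOutputs."""
    if len(inputs_shape) != 3:
        raise ValueError("inputs is not a 3-Tensor")
    max_time, batch_size, num_classes = inputs_shape
    # An empty tensor also covers the kernel's separate "max_time is 0" check.
    if max_time * batch_size * num_classes == 0:
        raise ValueError("inputs must not be empty")
    if len(seq_len) != batch_size:
        raise ValueError(
            f"len(sequence_length) != batch_size. len(sequence_length): "
            f"{len(seq_len)} batch_size: {batch_size}")
    for b, length in enumerate(seq_len):
        if not length <= max_time:
            # Same wording as the kernel: the message states the violated rule.
            raise ValueError(f"sequence_length({b}) <= {max_time}")
    return (batch_size, 1)  # shape of the log_probability output


print(validate_inputs((5, 2, 4), [5, 3]))  # (2, 1)
```

Once `inputs` is known to be non-empty, `max_time` cannot be 0. That is why the sketch folds the kernel's explicit max_time check into the emptiness check.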

CTCDecodeHelper::StoreAllDecodedSequences

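sequences is represented by the three outputs decoded_indices, decoded_values and decoded_shape.
top_paths_ is the number of top paths.
num_entries[p] accumulates, over all batch entries, the length of path p.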

  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    // Calculate the total number of entries for each path
    const int64_t batch_size = sequences.size();
    std::vector<int64_t> num_entries(top_paths_, 0);

    // Calculate num_entries per path
    for (const auto& batch_s : sequences) {
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

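For each top path, memory is allocated according to its entry count in num_entries.
OpOutputList::allocate calls OpKernelContext::allocate_output to create the Tensor and returns a pointer to it.
indices_t, values_t and shape_t hold the path's indices, label values and 2-D dense shape, respectively.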

    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64_t p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64_t>();
      auto values_t = p_values->vec<int64_t>();
      auto shape_t = p_shape->vec<int64_t>();

      int64_t max_decoded = 0;
      int64_t offset = 0;

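For each batch entry b, p_batch is path p of that entry and num_decoded is its length; its labels are copied into values_t.
Each row of indices_t is filled with (b, t), and offset is the running write position across batch entries.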

      for (int64_t b = 0; b < batch_size; ++b) {
        auto& p_batch = sequences[b][p];
        int64_t num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        if (num_decoded > 0) {
          DCHECK_NE(values_t.data(), nullptr)
              << "values_t should not be nullptr: p_num=" << p_num
              << " num_decoded=" << num_decoded;
          DCHECK_LT(offset, values_t.size())
              << "offset should be smaller than values_t.size()";
          std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        }
        for (int64_t t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return Status::OK();
  }
 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
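For the single path produced by the greedy decoder, the conversion amounts to the Python sketch below. `to_sparse` is a hypothetical helper, and plain lists stand in for the int64 output tensors. Rows of indices are [batch, time], and the dense shape is [batch_size, max_decoded_length], matching the docstring.

```python
def to_sparse(paths):
    """Convert one decoded path per batch entry into COO components.

    paths[b] plays the role of sequences[b][0] in the kernel (top_paths_ = 1).
    Returns (indices, values, dense_shape).
    """
    indices, values = [], []
    max_decoded = 0
    for b, path in enumerate(paths):
        max_decoded = max(max_decoded, len(path))
        values.extend(path)
        indices.extend([b, t] for t in range(len(path)))
    return indices, values, [len(paths), max_decoded]


print(to_sparse([[0], [2, 0]]))
# ([[0, 0], [1, 0], [1, 1]], [0, 2, 0], [2, 2])
```

An empty path contributes no rows but still counts toward batch_size in the dense shape, just as in the C++ loop.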

OpKernelContext::input

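It looks up the input's index by name, rejects ref inputs, and then returns the tensor held by the corresponding TensorValue element of the tensorflow::gtl::InlinedVector of inputs.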

  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used ref input name '", name,
                                   "' when non-ref input was expected");
  }
  *tensor = (*params_->inputs)[index].tensor;
  return Status::OK();

RowMax

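RowMax takes a TTypes<T>::UnalignedConstMatrix m, uses CHECK_LT to ensure the matrix has at least one column, then finds the maximum value in row r, returns it, and stores its column index in c. Because the comparison is a strict >, ties keep the first (lowest) index.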

template <typename T>
inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
                int* c) {
  *c = 0;
  CHECK_LT(0, m.dimension(1));
  auto p = m(r, 0);
  for (int i = 1; i < m.dimension(1); ++i) {
    if (m(r, i) > p) {
      p = m(r, i);
      *c = i;
    }
  }
  return p;
}

Shard

Figure: call graph of Shard (nodes: Shard, ThreadPool::ParallelFor).

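GetPerThreadMaxParallelism returns the thread_local variable per_thread_max_parallelism.
If total is 0, Shard returns immediately. If max_parallelism is at most 1, the work function is executed inline on the calling thread.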

  CHECK_GE(total, 0);
  if (total == 0) {
    return;
  }
  max_parallelism = std::min(max_parallelism, GetPerThreadMaxParallelism());
  if (max_parallelism <= 1) {
    // Just inline the whole work since we only have 1 thread (core).
    work(0, total);
    return;
  }

If max_parallelism is at least the number of threads in the pool, the work is processed in parallel by ThreadPool::ParallelFor.

  if (max_parallelism >= workers->NumThreads()) {
    workers->ParallelFor(total, cost_per_unit, work);
    return;
  }

Otherwise Shard falls back to Sharder::Do, an approach that is deprecated.

  Sharder::Do(
      total, cost_per_unit, work,
      [&workers](Sharder::Closure c) { workers->Schedule(c); },
      max_parallelism);
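The control flow of Shard can be sketched in Python with a thread pool. `shard` is a hypothetical simplification. It splits [0, total) into equal contiguous blocks and calls work(begin, end) on each. The real code instead derives block sizes from cost_per_unit through Eigen, and that cost model is not reproduced here.

```python
from concurrent.futures import ThreadPoolExecutor


def shard(num_workers, total, work, block_size=None):
    """Split [0, total) into contiguous blocks and run work(begin, end)."""
    if total < 0:
        raise ValueError("total must be non-negative")
    if total == 0:
        return
    if num_workers <= 1:
        work(0, total)  # inline, like the max_parallelism <= 1 branch
        return
    if block_size is None:
        block_size = -(-total // num_workers)  # ceiling division
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(work, begin, min(begin + block_size, total))
                   for begin in range(0, total, block_size)]
        for f in futures:
            f.result()  # re-raise any exception raised inside a worker
```

Each block covers a disjoint range of batch entries. This is why the decode lambda can write sequences[b] and log_prob_t(b, 0) without any locking.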

ThreadPool::ParallelFor

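ThreadPool::ParallelFor delegates to the parallelFor method of threadpool_device_, which is an Eigen::ThreadPoolDevice (hence the Eigen::TensorOpCost argument). This Eigen device is a different class from the tensorflow::ThreadPoolDevice CPU device shown below.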

  CHECK_GE(total, 0);
  CHECK_EQ(total, (int64_t)(Eigen::Index)total);
  threadpool_device_->parallelFor(
      total, Eigen::TensorOpCost(0, 0, cost_per_unit),
      [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });

ThreadPoolDevice

// CPU device implementation.
class ThreadPoolDevice : public LocalDevice {
 public:
  ThreadPoolDevice(const SessionOptions& options, const string& name,
                   Bytes memory_limit, const DeviceLocality& locality,
                   Allocator* allocator);
  ~ThreadPoolDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override;
  Allocator* GetScopedAllocator(AllocatorAttributes attr,
                                int64_t step_id) override;
  ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
    return scoped_allocator_mgr_.get();
  }
  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override;
  void CopyTensorInSameDevice(const Tensor* input_tensor,
                              Tensor* output_tensor,
                              const DeviceContext* device_context,
                              StatusCallback done) override;

  Status Sync() override { return Status::OK(); }

  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;

 private:
  void LogInputs(OpKernel* op_kernel, OpKernelContext* context);
  void LogOutputs(OpKernel* op_kernel, OpKernelContext* context);

  Allocator* allocator_;  // Not owned
  std::unique_ptr<ScopedAllocatorMgr> scoped_allocator_mgr_;
  NodeFileWriter* node_file_writer_ = nullptr;  // not owned
};
