提问者:小点点

DNN累加器训练中的错误


我在估计器DNNRegressor的帮助下构建回归模型。下面是代码

import tensorflow as tf

DATA_PATH = 'train_data/train_1.csv'
BATCH_SIZE = 5
N_FEATURES = 3963

def batch_generator(filenames):
    """ filenames is the list of files you want to read from. 
    In this case, it contains only heart.csv
    """
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TextLineReader(skip_header_lines=1) # skip the first line in the file
    _, value = reader.read(filename_queue)
    record_defaults = [[1.0] for _ in range(N_FEATURES)]

    # read in the rows of data
    content = tf.decode_csv(value, record_defaults=record_defaults) 

    # pack all features into a tensor
    features = tf.stack(content[:N_FEATURES])

    # assign the last column to label
    label = content[1]

    # minimum number elements in the queue after a dequeue, used to ensure 
    # that the samples are sufficiently mixed
    # I think 10 times the BATCH_SIZE is sufficient
    min_after_dequeue = 10 * BATCH_SIZE

    # the maximum number of elements in the queue
    capacity = 20 * BATCH_SIZE

    # shuffle the data to generate BATCH_SIZE sample pairs
    data_batch, label_batch = tf.train.shuffle_batch([features, label], batch_size=BATCH_SIZE, 
                                        capacity=capacity, min_after_dequeue=min_after_dequeue)

    return data_batch, label_batch

def generate_batches():
    regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,hidden_units=[10,10,10],model_dir='alg_model4')
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(4): # generate 10 batches
            regressor.train(input_fn=sess.run(input_fn()),steps=2)
        coord.request_stop()
        coord.join(threads)

def main():
    generate_batches()


if __name__ == '__main__':
    main()

以下是流程:-

  • 首先,我从包含前缀为train_的多个文件的目录中读取数据。
  • 图案像train_*. csv
  • 总共包含3963列。
  • 第二列是因变量所有类型都是整数
  • 我需要读取固定大小的批量数据集,并将其输入DNNRegressor以训练模型

问题是它抛出以下输出时出错:-信息:tensorflow:使用默认配置。信息:tensorflow:使用配置:{“日志”步骤计数步骤:100,“保存检查点”步骤:5,“保存检查点”步骤:600,“随机种子”:1,“保存摘要步骤”:100,“保存模型步骤”:“alg模型4”,“保存检查点”步骤:无”,“会话配置”:无”,“每小时保留检查点”:10000(TensorShape([Dimension(None)),维度(3963),维度,TensorShape([维度(无)])信息:tensorflow:向协调器报告的错误:,已取消出列操作[[Node:ReaderReadV2_7=ReaderReadV2[_device=“/job:localhost/replica:0/task:0/cpu:0”](TextLineReaderV2_7,input_producer_7)]--------------------------------------------------------------------------------------类型错误回溯(最近一次调用)/usr/lib/python3。5/检查。getfullargspec(func)1088中的py
跳过\u-bound\u-arg=False-

/usr/lib/python3.5/inspect.py在_signature_from_callable(obj,follow_wrapper_chains,skip_bound_arg, sigcls)2155如果不可调用(obj):-

TypeError:(数组([[0,1,0.,…,0,0,0.]),[ 0., 1., 0., ..., 0., 0., 0.], [ 0., 1., 0., ..., 0., 0., 0.], [ 0., 1., 0., ..., 0., 0., 0.], [ 0., 1., 0., ..., 1., 0., 0.]], dtype=float32),数组([42612203412040491414])不是可调用对象

上述例外是以下例外的直接原因:

如果name='main',则在()4 5中的TypeError回溯(最近一次调用last):----

在main()1 def main(): ----

在generate_batches()中,5个线程=tf。火车为范围(4)内的?启动?队列?运行程序(coord=coord)6:#生成10个批次----

/usr/local/lib/python3。5/dist包/tensorflow/python/estimator/estimator。列车内py(自身、输入、挂钩、台阶、最大台阶)239挂钩。附加(训练.停止Stepheook(步数,最大步数))240--

/usr/local/lib/python3。5/dist包/tensorflow/python/estimator/estimator。py in_train_模型(self,input_fn,hooks)626全局步骤张量=self_创建和断言步骤(g)627要素,标签=自我_从输入获取特征和标签(--

/usr/local/lib/python3。5/dist包/tensorflow/python/estimator/estimator。py in_get_features_和_labels_from_input_fn(self,input_fn,mode)497 498 def_get_features_和_labels_from_input_fn(self,input_fn,mode):--

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py_call_input_fn(解析参数失败)576""577 del模式#未使用-

/usr/local/lib/python3。5/dist包/tensorflow/python/estimator/util。fn#U参数(fn)55 56#句柄函数中的py---

/usr/local/lib/python3。5/dist-packages/tensorflow/python/util/tf_-inspect。getargspec(object)43修饰符中的py,target=tf_修饰符。展开(对象)44返回下一个((d.decorator_argspec for d in decorators---

/usr/lib/python3.5/inspect.py在getArgSpec(func)1041
stacklevel=2)1042 args, varargs, varkw,默认值, kWonlyargs, kWonlydefax, ann=\-

/usr/lib/python3.5/inspect.py在getful扩大规格(func)1093

-

不支持的可调用


共1个答案

匿名用户

不确定input_fn来自哪里,但是sess.run(input_fn())看起来不正确。估计器将评估input_fn