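Asked by: 小点点

Error: the CSV file header is included in the decision tree computation in scikit-learn


I am running the code below to build a decision tree with the scikit-learn library.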

import numpy as np
from sklearn.model_selection import train_test_split
import os
from sklearn.tree import export_graphviz
import graphviz

#progress 1
path="/mnt/d/TestDecisionTree/datasets"
os.chdir(path)
os.getcwd()

#progress 2
dataset=np.loadtxt("internetlogit.csv", delimiter=",")
x=dataset[:,0:5]
y=dataset[:,5]

#progress 3
from sklearn.tree import DecisionTreeRegressor
X_train, X_test, y_train, y_test = train_test_split(x, y)
tree = DecisionTreeRegressor().fit(X_train,y_train)

#progress 4
print("Training set accuracy: {:.3f}".format(tree.score(X_train, y_train)))
print("Test set accuracy: {:.3f}".format(tree.score(X_test, y_test)))

#progress 5
dtree = tree.predict(x)
print(dtree)

#progress 6
percentageerror_tree=((y-dtree)/dtree)*100
percentageerror_tree

#progress 7
np.mean(percentageerror_tree)

#progress 8
export_graphviz(tree,out_file="result/tree.dot")

with open("result/tree.dot") as f:
    dot_graph = f.read()

graphviz.Source(dot_graph)

My sample data is the following dataset, stored in internetlogit.csv:

age,gender,webpages,videohours,income,usage
36,0,32,0.061388889,6021,0
33,0,49,8.516666667,10239,1
46,1,22,0,1374,0
53,0,16,2.762222222,5376,0
27,1,30,0,1393,0
21,1,23,2.641111111,4866,0
42,0,30,0,1673,0
...

However, I get the following error at progress 2:

ValueError: could not convert string to float: 'age'

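This means the header of the CSV file is being included in the decision tree computation, which should not happen. How can I fix this?

Any help is appreciated.


1 answer

Anonymous user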

pandas would be the easiest drop-in fix, for example:

import pandas as pd
from io import StringIO

csv_file = StringIO("""
age,gender,webpages,videohours,income,usage
36,0,32,0.061388889,6021,0
33,0,49,8.516666667,10239,1
46,1,22,0,1374,0
53,0,16,2.762222222,5376,0
27,1,30,0,1393,0
21,1,23,2.641111111,4866,0
42,0,30,0,1673,0
""")

df = pd.read_csv(csv_file)

y = df["usage"]
x = df.drop(["usage"], axis=1)

from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(x, y)
tree = DecisionTreeRegressor()
tree.fit(X_train, y_train)
print(tree)
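
If the rest of the original script should keep its NumPy slicing, one possible approach is to convert the DataFrame back to an array after pandas has handled the header. This is only a sketch: the three rows are copied from the question, and the variable names mirror the question's code.

```python
import io
import pandas as pd

# Stand-in for internetlogit.csv, using three rows from the question.
csv_file = io.StringIO(
    "age,gender,webpages,videohours,income,usage\n"
    "36,0,32,0.061388889,6021,0\n"
    "33,0,49,8.516666667,10239,1\n"
    "46,1,22,0,1374,0\n"
)

# read_csv treats the first line as column names, not as data.
df = pd.read_csv(csv_file)

# Convert to a plain float array so the original slicing still works.
dataset = df.to_numpy(dtype=float)
x = dataset[:, 0:5]  # age, gender, webpages, videohours, income
y = dataset[:, 5]    # usage

print(x.shape, y.shape)
```

With real data, passing the file path to pd.read_csv in place of the StringIO object should behave the same way.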

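The ValueError is raised at line 13 of the posted code (the np.loadtxt call) because the column names in the header row cannot be converted to floats.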
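
The failure can be reproduced in isolation, which shows that it comes from parsing the file and not from the decision tree. The sketch below uses a two-line stand-in for internetlogit.csv; the exact wording of the error message varies between NumPy versions, so it is only printed.

```python
import io
import numpy as np

# Header line followed by one data row from the question.
csv_text = (
    "age,gender,webpages,videohours,income,usage\n"
    "36,0,32,0.061388889,6021,0\n"
)

try:
    # loadtxt tries to parse every field as a float, including "age".
    np.loadtxt(io.StringIO(csv_text), delimiter=",")
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```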

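If you don't want to use pandas, you can also pass skiprows to np.loadtxt. Here skiprows=2 because the csv_file string above starts with a blank line followed by the header; for the real internetlogit.csv, where the header is the first line, use skiprows=1.

import numpy as np

csv_file.seek(0)  # rewind, since pd.read_csv above already read the buffer to the end
dataset = np.loadtxt(csv_file, delimiter=",", skiprows=2)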
x = dataset[:,0:5]
y = dataset[:,5]
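
For a CSV on disk whose header is the first line, as in the question, skipping a single row should be enough. The following is a minimal sketch under that assumption; it writes a small stand-in file to a temporary directory so that it runs on its own, and the rows are taken from the question.

```python
import os
import tempfile
import numpy as np

# Stand-in for internetlogit.csv, using three rows from the question.
csv_text = (
    "age,gender,webpages,videohours,income,usage\n"
    "36,0,32,0.061388889,6021,0\n"
    "33,0,49,8.516666667,10239,1\n"
    "46,1,22,0,1374,0\n"
)

with tempfile.TemporaryDirectory() as tmp_dir:
    csv_path = os.path.join(tmp_dir, "internetlogit.csv")
    with open(csv_path, "w") as f:
        f.write(csv_text)

    # The header is the first line of the file, so skip exactly one row.
    dataset = np.loadtxt(csv_path, delimiter=",", skiprows=1)

x = dataset[:, 0:5]  # age, gender, webpages, videohours, income
y = dataset[:, 5]    # usage

print(x.shape, y.shape)
```

In the original script this amounts to adding skiprows=1 to the existing np.loadtxt call at progress 2.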