下载数据集

# Local filenames for the dataset files fetched from Google Drive via gdown (below).
tr_path = 'covid.train.csv'  # path to training data
tt_path = 'covid.test.csv' # path to testing data

!gdown --id '19CCyCgJrUxtvgZF53vnctJiOJ23T5mqF' --output covid.train.csv
!gdown --id '1CE240jLm2npU-tdz81-oVKEF3T2yfT1O' --output covid.test.csv
Downloading...
From: https://drive.google.com/uc?id=19CCyCgJrUxtvgZF53vnctJiOJ23T5mqF
To: /Users/baikal/machineLearning/lessonOne/covid.train.csv
100%|███████████████████████████████████████| 2.00M/2.00M [00:17<00:00, 115kB/s]
Downloading...
From: https://drive.google.com/uc?id=1CE240jLm2npU-tdz81-oVKEF3T2yfT1O
To: /Users/baikal/machineLearning/lessonOne/covid.test.csv
100%|█████████████████████████████████████████| 651k/651k [00:05<00:00, 125kB/s]

需要使用google下载工具下载google drive上的文件,安装方法:
pip install gdown

查看数据集

# The following three imports are new additions for this analysis
from sklearn.model_selection import train_test_split
import pandas as pd
import pprint as pp

# Read the training data
data_tr = pd.read_csv(tr_path)
# Read the testing data
data_tt = pd.read_csv(tt_path)
# Show the first 5 rows of the training data
data_tr.head(5)

ALAKAZARCACOCTFLGAID...restaurant.2spent_time.2large_event.2public_transit.2anxious.2depressed.2felt_isolated.2worried_become_ill.2worried_finances.2tested_positive.2
01.00.00.00.00.00.00.00.00.00.0...23.81241143.43042316.1515271.60263515.40944912.08868816.70208653.99154943.60422920.704935
11.00.00.00.00.00.00.00.00.00.0...23.68297443.19631316.1233861.64186315.23006311.80904716.50697354.18552142.66576621.292911
21.00.00.00.00.00.00.00.00.00.0...23.59398343.36220016.1599711.67752315.71720712.35591816.27329453.63706942.97241721.166656
31.00.00.00.00.00.00.00.00.00.0...22.57699242.95457415.5443731.57803015.29565012.21812316.04550452.44622342.90747219.896607
41.00.00.00.00.00.00.00.00.00.0...22.09143343.29095715.2146551.64166714.77880212.41725616.13423852.56031543.32198520.178428

5 rows × 94 columns

# Show the first 5 rows of the testing data
data_tt.head(5)

idALAKAZARCACOCTFLGA...shop.2restaurant.2spent_time.2large_event.2public_transit.2anxious.2depressed.2felt_isolated.2worried_become_ill.2worried_finances.2
000.00.00.00.00.01.00.00.00.0...52.0710908.62400129.3747925.3914132.75480419.69509813.68564524.74783766.19495044.873473
110.00.00.00.00.00.00.00.00.0...58.74246121.72018741.3757849.4501793.15008822.07571517.30207723.55962257.01500938.372829
220.00.00.00.00.00.00.00.00.0...59.10904520.12395940.0725568.7815222.88820923.92087018.34250624.99334155.29149838.907257
330.00.00.00.00.00.00.00.00.0...55.44226716.08352936.9776125.1992862.57534721.07380012.08717118.60872367.03619743.142779
440.00.00.00.00.00.00.00.00.0...60.58878319.50301042.63123611.5497718.53055115.89657511.78163415.06522861.19651843.574676

5 rows × 94 columns

# List all column names of the training set
data_tr.columns
Index(['id', 'AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'FL', 'GA', 'ID', 'IL',
       'IN', 'IA', 'KS', 'KY', 'LA', 'MD', 'MA', 'MI', 'MN', 'MS', 'MO', 'NE',
       'NV', 'NJ', 'NM', 'NY', 'NC', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'TX',
       'UT', 'VA', 'WA', 'WV', 'WI', 'cli', 'ili', 'hh_cmnty_cli',
       'nohh_cmnty_cli', 'wearing_mask', 'travel_outside_state',
       'work_outside_home', 'shop', 'restaurant', 'spent_time', 'large_event',
       'public_transit', 'anxious', 'depressed', 'felt_isolated',
       'worried_become_ill', 'worried_finances', 'tested_positive', 'cli.1',
       'ili.1', 'hh_cmnty_cli.1', 'nohh_cmnty_cli.1', 'wearing_mask.1',
       'travel_outside_state.1', 'work_outside_home.1', 'shop.1',
       'restaurant.1', 'spent_time.1', 'large_event.1', 'public_transit.1',
       'anxious.1', 'depressed.1', 'felt_isolated.1', 'worried_become_ill.1',
       'worried_finances.1', 'tested_positive.1', 'cli.2', 'ili.2',
       'hh_cmnty_cli.2', 'nohh_cmnty_cli.2', 'wearing_mask.2',
       'travel_outside_state.2', 'work_outside_home.2', 'shop.2',
       'restaurant.2', 'spent_time.2', 'large_event.2', 'public_transit.2',
       'anxious.2', 'depressed.2', 'felt_isolated.2', 'worried_become_ill.2',
       'worried_finances.2', 'tested_positive.2'],
      dtype='object')
# The 'id' column is not a feature; drop it from both sets (in place)
data_tr.drop(['id'], axis = 1, inplace = True)
data_tt.drop(['id'], axis = 1, inplace = True)
# Keep the remaining feature column names for index lookups later
cols = list(data_tr.columns)
pp.pprint(data_tr.columns)
Index(['AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'FL', 'GA', 'ID', 'IL', 'IN',
       'IA', 'KS', 'KY', 'LA', 'MD', 'MA', 'MI', 'MN', 'MS', 'MO', 'NE', 'NV',
       'NJ', 'NM', 'NY', 'NC', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'TX', 'UT',
       'VA', 'WA', 'WV', 'WI', 'cli', 'ili', 'hh_cmnty_cli', 'nohh_cmnty_cli',
       'wearing_mask', 'travel_outside_state', 'work_outside_home', 'shop',
       'restaurant', 'spent_time', 'large_event', 'public_transit', 'anxious',
       'depressed', 'felt_isolated', 'worried_become_ill', 'worried_finances',
       'tested_positive', 'cli.1', 'ili.1', 'hh_cmnty_cli.1',
       'nohh_cmnty_cli.1', 'wearing_mask.1', 'travel_outside_state.1',
       'work_outside_home.1', 'shop.1', 'restaurant.1', 'spent_time.1',
       'large_event.1', 'public_transit.1', 'anxious.1', 'depressed.1',
       'felt_isolated.1', 'worried_become_ill.1', 'worried_finances.1',
       'tested_positive.1', 'cli.2', 'ili.2', 'hh_cmnty_cli.2',
       'nohh_cmnty_cli.2', 'wearing_mask.2', 'travel_outside_state.2',
       'work_outside_home.2', 'shop.2', 'restaurant.2', 'spent_time.2',
       'large_event.2', 'public_transit.2', 'anxious.2', 'depressed.2',
       'felt_isolated.2', 'worried_become_ill.2', 'worried_finances.2',
       'tested_positive.2'],
      dtype='object')
# Inspect each column's dtype, non-null count and the memory footprint
pp.pprint(data_tr.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2700 entries, 0 to 2699
Data columns (total 94 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   AL                      2700 non-null   float64
 1   AK                      2700 non-null   float64
 2   AZ                      2700 non-null   float64
 3   AR                      2700 non-null   float64
 4   CA                      2700 non-null   float64
 5   CO                      2700 non-null   float64
 6   CT                      2700 non-null   float64
 7   FL                      2700 non-null   float64
 8   GA                      2700 non-null   float64
 9   ID                      2700 non-null   float64
 10  IL                      2700 non-null   float64
 11  IN                      2700 non-null   float64
 12  IA                      2700 non-null   float64
 13  KS                      2700 non-null   float64
 14  KY                      2700 non-null   float64
 15  LA                      2700 non-null   float64
 16  MD                      2700 non-null   float64
 17  MA                      2700 non-null   float64
 18  MI                      2700 non-null   float64
 19  MN                      2700 non-null   float64
 20  MS                      2700 non-null   float64
 21  MO                      2700 non-null   float64
 22  NE                      2700 non-null   float64
 23  NV                      2700 non-null   float64
 24  NJ                      2700 non-null   float64
 25  NM                      2700 non-null   float64
 26  NY                      2700 non-null   float64
 27  NC                      2700 non-null   float64
 28  OH                      2700 non-null   float64
 29  OK                      2700 non-null   float64
 30  OR                      2700 non-null   float64
 31  PA                      2700 non-null   float64
 32  RI                      2700 non-null   float64
 33  SC                      2700 non-null   float64
 34  TX                      2700 non-null   float64
 35  UT                      2700 non-null   float64
 36  VA                      2700 non-null   float64
 37  WA                      2700 non-null   float64
 38  WV                      2700 non-null   float64
 39  WI                      2700 non-null   float64
 40  cli                     2700 non-null   float64
 41  ili                     2700 non-null   float64
 42  hh_cmnty_cli            2700 non-null   float64
 43  nohh_cmnty_cli          2700 non-null   float64
 44  wearing_mask            2700 non-null   float64
 45  travel_outside_state    2700 non-null   float64
 46  work_outside_home       2700 non-null   float64
 47  shop                    2700 non-null   float64
 48  restaurant              2700 non-null   float64
 49  spent_time              2700 non-null   float64
 50  large_event             2700 non-null   float64
 51  public_transit          2700 non-null   float64
 52  anxious                 2700 non-null   float64
 53  depressed               2700 non-null   float64
 54  felt_isolated           2700 non-null   float64
 55  worried_become_ill      2700 non-null   float64
 56  worried_finances        2700 non-null   float64
 57  tested_positive         2700 non-null   float64
 58  cli.1                   2700 non-null   float64
 59  ili.1                   2700 non-null   float64
 60  hh_cmnty_cli.1          2700 non-null   float64
 61  nohh_cmnty_cli.1        2700 non-null   float64
 62  wearing_mask.1          2700 non-null   float64
 63  travel_outside_state.1  2700 non-null   float64
 64  work_outside_home.1     2700 non-null   float64
 65  shop.1                  2700 non-null   float64
 66  restaurant.1            2700 non-null   float64
 67  spent_time.1            2700 non-null   float64
 68  large_event.1           2700 non-null   float64
 69  public_transit.1        2700 non-null   float64
 70  anxious.1               2700 non-null   float64
 71  depressed.1             2700 non-null   float64
 72  felt_isolated.1         2700 non-null   float64
 73  worried_become_ill.1    2700 non-null   float64
 74  worried_finances.1      2700 non-null   float64
 75  tested_positive.1       2700 non-null   float64
 76  cli.2                   2700 non-null   float64
 77  ili.2                   2700 non-null   float64
 78  hh_cmnty_cli.2          2700 non-null   float64
 79  nohh_cmnty_cli.2        2700 non-null   float64
 80  wearing_mask.2          2700 non-null   float64
 81  travel_outside_state.2  2700 non-null   float64
 82  work_outside_home.2     2700 non-null   float64
 83  shop.2                  2700 non-null   float64
 84  restaurant.2            2700 non-null   float64
 85  spent_time.2            2700 non-null   float64
 86  large_event.2           2700 non-null   float64
 87  public_transit.2        2700 non-null   float64
 88  anxious.2               2700 non-null   float64
 89  depressed.2             2700 non-null   float64
 90  felt_isolated.2         2700 non-null   float64
 91  worried_become_ill.2    2700 non-null   float64
 92  worried_finances.2      2700 non-null   float64
 93  tested_positive.2       2700 non-null   float64
dtypes: float64(94)
memory usage: 1.9 MB
None
# 'WI' is the last of the one-hot encoded state columns (values 0 or 1);
# the state columns must be excluded from the later feature analysis
WI_index = cols.index('WI')
# column index of 'WI' is 39
WI_index
39
# The column right after 'WI' is 'cli', so the real feature columns start at
# index 40; inspect the distribution of those columns
'''
loc函数:通过索引 "Index" 中的具体值来取行数据(如取"Index"为"A"的行)
dataFrame.loc[:, :]
iloc函数:通过行号、列号来取行数据(如取第二行的数据)
dataFrame.iloc[:, :] -> dataFrame.iloc[x.begin: x.end, y.begin: y.end]
'''
data_tr.iloc[:, 40:].describe()

cliilihh_cmnty_clinohh_cmnty_cliwearing_masktravel_outside_statework_outside_homeshoprestaurantspent_time...restaurant.2spent_time.2large_event.2public_transit.2anxious.2depressed.2felt_isolated.2worried_become_ill.2worried_finances.2tested_positive.2
count2700.0000002700.0000002700.0000002700.0000002700.0000002700.0000002700.0000002700.0000002700.0000002700.000000...2700.0000002700.0000002700.0000002700.0000002700.0000002700.0000002700.0000002700.0000002700.0000002700.000000
mean0.9915871.01613629.44249624.32305489.6823228.89449831.70330755.27715316.69434236.283177...16.57829036.07494110.2574742.38573518.06763513.05882819.24328364.83430744.56844016.431280
std0.4202960.4236299.0937388.4467505.3800273.4040274.9289024.5259175.6684796.675206...5.6515836.6551664.6862631.0531472.2500811.6285892.7083396.2200875.2320307.619354
min0.1263210.1324709.9616406.85718170.9509121.25298318.31194143.2201873.63741421.485815...3.63741421.4858152.1186740.72877012.9807868.37053613.40039948.22560333.1138822.338708
25%0.6739290.69751523.20316518.53915386.3095376.17775428.24786551.54720613.31105030.740931...13.20053230.6067116.5325431.71408016.42048511.91416717.32291259.78287640.54998710.327314
50%0.9127470.94029528.95573823.81976190.8194358.28828832.14314055.25726216.37169936.267966...16.22701036.0413899.7003682.19952117.68419712.94874918.76026765.93225943.99763715.646480
75%1.2668491.30204036.10911430.23806193.93711911.58220935.38731558.86613021.39697141.659971...21.20716241.50852013.6025662.73046919.50341914.21432020.71363869.71965148.11828322.535165
max2.5977322.62588556.83228951.55045098.08716018.55232542.35907465.67388928.48822050.606465...28.48822050.60646524.4967118.16227528.57409118.71594428.36627077.70101458.43360040.959495

8 rows × 54 columns

# Inspect the test-set feature distributions and compare them with the
# training set; the differences between the two are small
data_tt.iloc[:, 40:].describe()

cliilihh_cmnty_clinohh_cmnty_cliwearing_masktravel_outside_statework_outside_homeshoprestaurantspent_time...shop.2restaurant.2spent_time.2large_event.2public_transit.2anxious.2depressed.2felt_isolated.2worried_become_ill.2worried_finances.2
count893.000000893.000000893.000000893.000000893.000000893.000000893.000000893.000000893.000000893.000000...893.000000893.000000893.000000893.000000893.000000893.000000893.000000893.000000893.000000893.000000
mean0.9724570.99180929.07568224.01872989.6375069.00132531.62060755.42298216.55438736.371653...55.26862816.44491636.16589810.2489752.36911517.98814712.99383019.23872364.61992044.411505
std0.4119970.4154689.5962908.9882454.7335493.6556164.7545704.3667805.6888026.203232...4.3505405.6568286.1922744.4988451.1143662.2070221.7131432.6874355.6858654.605268
min0.1395580.1594779.1713156.01474076.8952782.06250018.29919844.0624423.80068421.487077...44.6718913.83744121.3384252.3346550.87398612.6969778.46244413.47620950.21223435.072577
25%0.6733270.68936721.83173017.38549086.5874757.05503928.75517851.72698713.31424231.427591...51.59430113.39176931.3304696.8028601.76037416.40639711.77710117.19731360.35820340.910546
50%0.9252300.93661028.18301423.03574990.1231338.77324331.82638555.75088717.10055636.692799...55.49032516.97541036.2135949.5503932.14646817.71976012.80542419.06865865.14812844.504010
75%1.2512191.26746336.81377231.14186693.38795210.45226235.18492659.18535020.91996141.265159...59.07847520.58437641.07103513.3727312.64531419.42372014.09155121.20569568.99430947.172065
max2.4889672.52226353.18406748.14243397.84322126.59875242.88726363.97900727.43828653.513289...63.77109727.36232152.04537323.3056309.11830227.00356418.96415726.00755776.87105356.442135

8 rows × 53 columns

# For plotting
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
%matplotlib inline
# Eyeball the correlation between the 'cli' feature and the target
plt.scatter(data_tr.loc[:, 'cli'], data_tr.loc[:, 'tested_positive.2'])
<matplotlib.collections.PathCollection at 0x16c331670>

png

# Eyeball the correlation between the 'ili' feature and the target
plt.scatter(data_tr.loc[:, 'ili'], data_tr.loc[:, 'tested_positive.2'])
<matplotlib.collections.PathCollection at 0x1380acf40>

png

# 'cli' and 'ili' look nearly identical, so keeping one of the two is enough
plt.scatter(data_tr.loc[:, 'cli'], data_tr.loc[:, 'ili'])
<matplotlib.collections.PathCollection at 0x13811ca30>

png

# Day-1 target vs. day-3 target: clearly linearly correlated
plt.scatter(data_tr.loc[:, 'tested_positive'], data_tr.loc[:, 'tested_positive.2'])
<matplotlib.collections.PathCollection at 0x13815b730>

png

# Day-2 target vs. day-3 target: clearly linearly correlated
plt.scatter(data_tr.loc[:, 'tested_positive.1'], data_tr.loc[:, 'tested_positive.2'])
<matplotlib.collections.PathCollection at 0x1381ee190>

png

# Manual pairwise inspection is tedious; compute the full correlation matrix
data_tr.iloc[:, 40:].corr()

cliilihh_cmnty_clinohh_cmnty_cliwearing_masktravel_outside_statework_outside_homeshoprestaurantspent_time...restaurant.2spent_time.2large_event.2public_transit.2anxious.2depressed.2felt_isolated.2worried_become_ill.2worried_finances.2tested_positive.2
cli1.0000000.9957350.8934160.882322-0.107406-0.0959640.087305-0.364165-0.143318-0.209020...-0.151291-0.222834-0.060308-0.3740710.2371350.0814560.0983450.2287500.5505640.838504
ili0.9957351.0000000.8897290.878280-0.109015-0.1069340.086355-0.357443-0.142082-0.207210...-0.150141-0.220942-0.061298-0.3638730.2452280.0862290.1042500.2229090.5447760.830527
hh_cmnty_cli0.8934160.8897291.0000000.997225-0.035441-0.0695950.079219-0.472746-0.247043-0.293775...-0.253615-0.300062-0.136937-0.4332760.3075810.1814970.2035770.3502550.5619420.879724
nohh_cmnty_cli0.8823220.8782800.9972251.000000-0.046063-0.0619140.097756-0.465374-0.238106-0.280916...-0.245265-0.287482-0.129474-0.4249960.3178360.1884670.2035990.3454480.5347110.869938
wearing_mask-0.107406-0.109015-0.035441-0.0460631.000000-0.220808-0.735649-0.691597-0.788714-0.807623...-0.785281-0.802659-0.8890210.1334870.204031-0.0677200.4275330.8405280.340101-0.069531
travel_outside_state-0.095964-0.106934-0.069595-0.061914-0.2208081.0000000.2641070.2569110.2884730.349829...0.2880980.3369370.319736-0.2036110.0015920.064425-0.370776-0.131961-0.093096-0.097303
work_outside_home0.0873050.0863550.0792190.097756-0.7356490.2641071.0000000.6319580.7436730.698047...0.7303490.7055330.758575-0.1101760.0182590.075562-0.430307-0.652231-0.3177170.034865
shop-0.364165-0.357443-0.472746-0.465374-0.6915970.2569110.6319581.0000000.8209160.819035...0.8110550.8383580.7872370.130046-0.228007-0.029168-0.496368-0.866789-0.475304-0.410430
restaurant-0.143318-0.142082-0.247043-0.238106-0.7887140.2884730.7436730.8209161.0000000.878576...0.9933580.8761070.909089-0.046081-0.278715-0.074727-0.648631-0.832131-0.430842-0.157945
spent_time-0.209020-0.207210-0.293775-0.280916-0.8076230.3498290.6980470.8190350.8785761.000000...0.8753650.9867130.912682-0.040623-0.1699650.105281-0.517139-0.867460-0.522985-0.252125
large_event-0.042033-0.043535-0.124151-0.116761-0.8949700.3242700.7673050.7818620.9124490.918504...0.9105790.9138140.993111-0.139139-0.2155980.055579-0.565014-0.874083-0.372589-0.052473
public_transit-0.367103-0.356652-0.432142-0.4237730.131350-0.198308-0.1100770.132385-0.043954-0.037282...-0.048799-0.035965-0.1370800.982095-0.055799-0.1675990.001697-0.046611-0.138801-0.448360
anxious0.2738740.2819740.3367480.3440740.232620-0.0231750.013537-0.265503-0.312912-0.209830...-0.327660-0.218920-0.283515-0.0542700.9511960.5395960.5162520.2800870.2179880.173295
depressed0.0980330.1027150.1847390.190062-0.0700220.0585480.075801-0.041607-0.0740590.104628...-0.0659030.1139340.063086-0.1659720.5994230.9531570.592656-0.0556940.0212740.037689
felt_isolated0.1009280.1070790.1981760.1976610.422058-0.376858-0.431247-0.491608-0.642316-0.511772...-0.633869-0.497951-0.5446780.0097420.5263450.6044160.9783030.3956060.1280470.082182
worried_become_ill0.2185020.2129310.3444570.3401920.843990-0.136811-0.656085-0.864583-0.835101-0.870365...-0.831439-0.869933-0.872394-0.0435750.251045-0.0384210.4199400.9929760.4901270.262211
worried_finances0.5376080.5322170.5524310.5240220.354130-0.096444-0.339975-0.489539-0.447892-0.536561...-0.451124-0.536959-0.397443-0.1411400.1525000.0273820.1442300.5069070.9881230.475462
tested_positive0.8391220.8297560.8801870.869674-0.049350-0.1137260.025780-0.427815-0.173726-0.275476...-0.174815-0.278257-0.083275-0.4518090.1328020.0217730.0900150.2850520.4957530.981165
cli.10.9803790.9772250.8879440.877606-0.121569-0.0911860.096755-0.348133-0.129772-0.189519...-0.138355-0.204750-0.044520-0.3690660.2549110.0882430.0930920.2127210.5409810.838224
ili.10.9761710.9804730.8840200.873424-0.123680-0.1026450.096343-0.340973-0.128114-0.187173...-0.136904-0.202412-0.045430-0.3584470.2632780.0928250.0994120.2063780.5347510.829200
hh_cmnty_cli.10.8962110.8926670.9983560.996165-0.046423-0.0636190.089934-0.462807-0.235459-0.280262...-0.242962-0.288664-0.125902-0.4322340.3196970.1807030.1952940.3402230.5565470.879438
nohh_cmnty_cli.10.8851780.8812920.9951760.998259-0.056529-0.0558230.107979-0.455990-0.226870-0.268086...-0.234893-0.276769-0.119138-0.4234340.3288170.1864800.1952570.3361890.5289940.869278
wearing_mask.1-0.101056-0.102606-0.030237-0.0407380.998287-0.220397-0.732848-0.694338-0.789257-0.808963...-0.787873-0.806218-0.8927120.1308920.208011-0.0716890.4258300.8434690.342057-0.065600
travel_outside_state.1-0.097092-0.107662-0.069270-0.062039-0.2204420.9958380.2597480.2613350.2869210.352038...0.2873320.3426780.321376-0.202010-0.0038030.065990-0.372008-0.133520-0.090896-0.100407
work_outside_home.10.0870800.0859660.0749720.093529-0.7375540.2688640.9914710.6163940.7466800.697270...0.7377490.6986910.761755-0.1103370.0254300.077455-0.429387-0.652367-0.3252940.037930
shop.1-0.367850-0.361304-0.474799-0.467316-0.6886270.2524610.6385000.9912480.8202640.808526...0.8153170.8293630.7856310.131734-0.232298-0.029772-0.492524-0.864694-0.478978-0.412705
restaurant.1-0.147491-0.146353-0.250349-0.241687-0.7872450.2881600.7377250.8164140.9974960.877051...0.9974840.8765080.911182-0.046082-0.285147-0.070812-0.645411-0.832047-0.433497-0.159121
spent_time.1-0.216168-0.214354-0.297071-0.284398-0.8054680.3438540.7001460.8289920.8776600.995393...0.8761800.9953830.916829-0.039405-0.1750190.111992-0.513196-0.869427-0.523476-0.255714
large_event.1-0.051724-0.052961-0.130729-0.123252-0.8922670.3221490.7625920.7849960.9111430.916514...0.9120150.9173600.997449-0.137279-0.2262870.061477-0.559791-0.875258-0.376785-0.058079
public_transit.1-0.371063-0.360574-0.432765-0.4244450.132301-0.201241-0.1097270.131371-0.044942-0.039224...-0.047397-0.036646-0.1362000.991364-0.053367-0.1654620.005809-0.047158-0.140425-0.449079
anxious.10.2567120.2648720.3230530.3317910.217574-0.0110440.018079-0.246039-0.295416-0.189704...-0.309644-0.199497-0.260678-0.0531890.9809650.5676510.5199360.2646190.2019120.164537
depressed.10.0886760.0933710.1823830.188544-0.0693690.0617820.075357-0.034364-0.0738140.105809...-0.0660760.1170990.065712-0.1649730.5999520.9786230.601982-0.0545010.0241230.033149
felt_isolated.10.0994870.1054460.2010340.2008430.424822-0.374146-0.430562-0.493842-0.645507-0.514850...-0.638136-0.503093-0.5494720.0098420.5270400.6088960.9904460.4015430.1300050.081521
worried_become_ill.10.2233260.2177390.3475620.3430240.842499-0.134507-0.654251-0.865601-0.833903-0.869399...-0.832038-0.870442-0.874175-0.0452390.250985-0.0448860.4144440.9968780.4929980.264816
worried_finances.10.5433730.5378740.5573640.5295140.347359-0.094679-0.328919-0.482534-0.439702-0.529935...-0.444050-0.531072-0.389929-0.1417960.1689310.0265220.1382860.5017130.9948640.480958
tested_positive.10.8399290.8311290.8804160.870315-0.059477-0.1054670.031094-0.419104-0.165959-0.264309...-0.167639-0.268959-0.073982-0.4513970.1433950.0252720.0854170.2763380.4910430.991012
cli.20.9570590.9549960.8817680.872292-0.135146-0.0863320.104981-0.331428-0.116415-0.170275...-0.124823-0.185582-0.027097-0.3638150.2708110.0962700.0875260.1974070.5327700.835751
ili.20.9527070.9569790.8775500.867896-0.137841-0.0979910.104965-0.323789-0.114323-0.167358...-0.123005-0.182693-0.027895-0.3532090.2794850.1009970.0944630.1904360.5260260.826075
hh_cmnty_cli.20.8980670.8945640.9953960.993750-0.058149-0.0571640.099741-0.452086-0.223203-0.265245...-0.231610-0.275863-0.113619-0.4311420.3308820.1809630.1866530.3293300.5502900.878218
nohh_cmnty_cli.20.8871030.8832630.9917380.995093-0.067698-0.0492810.117226-0.445815-0.215113-0.253751...-0.223909-0.264597-0.107674-0.4218050.3390480.1855140.1865170.3260800.5225060.867535
wearing_mask.2-0.094664-0.096315-0.025367-0.0357590.995953-0.219423-0.729730-0.696457-0.788931-0.809003...-0.789539-0.808958-0.8957330.1286960.212752-0.0755990.4233250.8457210.343891-0.062037
travel_outside_state.2-0.097903-0.107903-0.069043-0.062137-0.2199160.9893100.2584300.2664380.2853800.352962...0.2868990.3478040.322521-0.199731-0.0079960.067252-0.372366-0.135255-0.089308-0.103868
work_outside_home.20.0859130.0847080.0699330.088394-0.7391120.2753480.9750170.5993630.7481850.700309...0.7436920.6952020.765953-0.1114770.0288030.080485-0.428880-0.652395-0.3330700.039304
shop.2-0.370197-0.363795-0.476538-0.469026-0.6854370.2496700.6409720.9778900.8180730.800586...0.8185090.8197550.7830570.132409-0.237570-0.031062-0.488979-0.862711-0.482649-0.415130
restaurant.2-0.151291-0.150141-0.253615-0.245265-0.7852810.2880980.7303490.8110550.9933580.875365...1.0000000.8765420.912564-0.046246-0.292246-0.067040-0.641984-0.831868-0.435929-0.160181
spent_time.2-0.222834-0.220942-0.300062-0.287482-0.8026590.3369370.7055330.8383580.8761070.986713...0.8765421.0000000.918931-0.037616-0.1802940.118125-0.507902-0.870630-0.524228-0.258956
large_event.2-0.060308-0.061298-0.136937-0.129474-0.8890210.3197360.7585750.7872370.9090890.912682...0.9125640.9189311.000000-0.135339-0.2385860.066021-0.554675-0.875487-0.380926-0.063709
public_transit.2-0.374071-0.363873-0.433276-0.4249960.133487-0.203611-0.1101760.130046-0.046081-0.040623...-0.046246-0.037616-0.1353391.000000-0.052253-0.1640790.009571-0.047068-0.142098-0.450436
anxious.20.2371350.2452280.3075810.3178360.2040310.0015920.018259-0.228007-0.278715-0.169965...-0.292246-0.180294-0.238586-0.0522531.0000000.5947970.5251710.2515090.1841260.152903
depressed.20.0814560.0862290.1814970.188467-0.0677200.0644250.075562-0.029168-0.0747270.105281...-0.0670400.1181250.066021-0.1640790.5947971.0000000.610310-0.0512460.0266210.029578
felt_isolated.20.0983450.1042500.2035770.2035990.427533-0.370776-0.430307-0.496368-0.648631-0.517139...-0.641984-0.507902-0.5546750.0095710.5251710.6103101.0000000.4079310.1324650.081174
worried_become_ill.20.2287500.2229090.3502550.3454480.840528-0.131961-0.652231-0.866789-0.832131-0.867460...-0.831868-0.870630-0.875487-0.0470680.251509-0.0512460.4079311.0000000.4958900.267610
worried_finances.20.5505640.5447760.5619420.5347110.340101-0.093096-0.317717-0.475304-0.430842-0.522985...-0.435929-0.524228-0.380926-0.1420980.1841260.0266210.1324650.4958901.0000000.485843
tested_positive.20.8385040.8305270.8797240.869938-0.069531-0.0973030.034865-0.410430-0.157945-0.252125...-0.160181-0.258956-0.063709-0.4504360.1529030.0295780.0811740.2676100.4858431.000000

54 rows × 54 columns

# Take the last column of the correlation matrix, i.e. every feature's
# correlation with the day-3 target value
data_corr = data_tr.iloc[:, 40:].corr()
target_col = data_corr['tested_positive.2']
target_col
cli                       0.838504
ili                       0.830527
hh_cmnty_cli              0.879724
nohh_cmnty_cli            0.869938
wearing_mask             -0.069531
travel_outside_state     -0.097303
work_outside_home         0.034865
shop                     -0.410430
restaurant               -0.157945
spent_time               -0.252125
large_event              -0.052473
public_transit           -0.448360
anxious                   0.173295
depressed                 0.037689
felt_isolated             0.082182
worried_become_ill        0.262211
worried_finances          0.475462
tested_positive           0.981165
cli.1                     0.838224
ili.1                     0.829200
hh_cmnty_cli.1            0.879438
nohh_cmnty_cli.1          0.869278
wearing_mask.1           -0.065600
travel_outside_state.1   -0.100407
work_outside_home.1       0.037930
shop.1                   -0.412705
restaurant.1             -0.159121
spent_time.1             -0.255714
large_event.1            -0.058079
public_transit.1         -0.449079
anxious.1                 0.164537
depressed.1               0.033149
felt_isolated.1           0.081521
worried_become_ill.1      0.264816
worried_finances.1        0.480958
tested_positive.1         0.991012
cli.2                     0.835751
ili.2                     0.826075
hh_cmnty_cli.2            0.878218
nohh_cmnty_cli.2          0.867535
wearing_mask.2           -0.062037
travel_outside_state.2   -0.103868
work_outside_home.2       0.039304
shop.2                   -0.415130
restaurant.2             -0.160181
spent_time.2             -0.258956
large_event.2            -0.063709
public_transit.2         -0.450436
anxious.2                 0.152903
depressed.2               0.029578
felt_isolated.2           0.081174
worried_become_ill.2      0.267610
worried_finances.2        0.485843
tested_positive.2         1.000000
Name: tested_positive.2, dtype: float64
# Keep the rows (features) whose correlation with the target exceeds 0.8;
# the 0.8 threshold is a hand-tuned hyper-parameter — adjust it as needed
feature = target_col[target_col > 0.8]
feature
cli                  0.838504
ili                  0.830527
hh_cmnty_cli         0.879724
nohh_cmnty_cli       0.869938
tested_positive      0.981165
cli.1                0.838224
ili.1                0.829200
hh_cmnty_cli.1       0.879438
nohh_cmnty_cli.1     0.869278
tested_positive.1    0.991012
cli.2                0.835751
ili.2                0.826075
hh_cmnty_cli.2       0.878218
nohh_cmnty_cli.2     0.867535
tested_positive.2    1.000000
Name: tested_positive.2, dtype: float64
feature_cols = feature.index.tolist()  # extract the selected feature names
feature_cols.pop() # drop the trailing 'tested_positive.2' label itself
pp.pprint(feature_cols) # the resulting list of feature names
['cli',
 'ili',
 'hh_cmnty_cli',
 'nohh_cmnty_cli',
 'tested_positive',
 'cli.1',
 'ili.1',
 'hh_cmnty_cli.1',
 'nohh_cmnty_cli.1',
 'tested_positive.1',
 'cli.2',
 'ili.2',
 'hh_cmnty_cli.2',
 'nohh_cmnty_cli.2']
# Map the selected feature names to their column indices; later the model can
# use (state features + feats_selected) as its input features
feats_selected = [cols.index(col) for col in feature_cols]
feats_selected
[40, 41, 42, 43, 57, 58, 59, 60, 61, 75, 76, 77, 78, 79]

导入包

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# For data preprocess
import numpy as np
import csv
import os

# For plotting
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

# Fix all random seeds so results are reproducible across runs.
myseed = 42069  # set a random seed for reproducibility
torch.backends.cudnn.deterministic = True  # force cuDNN to pick deterministic algorithms
torch.backends.cudnn.benchmark = False     # disable the (non-deterministic) cuDNN auto-tuner
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    # Original line had lost its indentation (IndentationError); it must be
    # the body of the guard so it only runs when CUDA is present.
    torch.cuda.manual_seed_all(myseed)

导入工具

无需修改

def get_device():
    """Return the compute device: 'cuda' when a GPU is visible, else 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

def plot_learning_curve(loss_record, title=''):
    """Plot the train/dev MSE-loss curves recorded during training.

    loss_record: dict with keys 'train' and 'dev', each a list of losses.
    The dev losses are recorded less frequently, so their x positions are
    spread out to align with the training steps.
    """
    n_train = len(loss_record['train'])
    train_steps = range(n_train)
    # one dev point every (n_train // n_dev) training steps
    dev_steps = train_steps[::n_train // len(loss_record['dev'])]
    figure(figsize=(6, 4))
    plt.plot(train_steps, loss_record['train'], c='tab:red', label='train')
    plt.plot(dev_steps, loss_record['dev'], c='tab:cyan', label='dev')
    plt.ylim(0.0, 5.)
    plt.xlabel('Training steps')
    plt.ylabel('MSE loss')
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()


def plot_pred(dv_set, model, device, lim=35., preds=None, targets=None):
    """Scatter-plot model predictions against ground-truth values.

    When preds/targets are not supplied, they are computed by running
    `model` over `dv_set` on `device` with gradient tracking disabled.
    """
    if preds is None or targets is None:
        model.eval()
        pred_chunks, target_chunks = [], []
        for x, y in dv_set:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred_chunks.append(model(x).detach().cpu())
            target_chunks.append(y.detach().cpu())
        preds = torch.cat(pred_chunks, dim=0).numpy()
        targets = torch.cat(target_chunks, dim=0).numpy()

    figure(figsize=(5, 5))
    plt.scatter(targets, preds, c='r', alpha=0.5)
    plt.plot([-0.2, lim], [-0.2, lim], c='b')  # ideal y = x reference line
    plt.xlim(-0.2, lim)
    plt.ylim(-0.2, lim)
    plt.xlabel('ground truth value')
    plt.ylabel('predicted value')
    plt.title('Ground Truth v.s. Prediction')
    plt.show()

预处理

我们有三种数据集:

  • 训练集
  • 验证集
  • 测试集

数据集

COVID19Dataset完成以下操作:

  1. 读取.csv文件
  2. 提取特征
  3. 划分covid.train.csv为训练集和验证集
  4. 规范特征

提示: 完成以下操作就可以通过中等难度的分数线

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# For data preprocess
import numpy as np
import csv
import os
class COVID19Dataset(Dataset):
''' Dataset for loading and preprocessing the COVID19 dataset '''
def __init__(self, path, mu, std, mode='train', target_only=False):
    """Load, split and normalise the COVID19 csv data.

    path: csv file to read.
    mu / std: normalisation statistics from the training split; ignored when
        mode == 'train' (recomputed from this split's own data). The baseline
        normalised with each split's own statistics, which leaks; dev/test
        must reuse the training-set statistics.
    mode: one of 'train', 'dev', 'test'.
    target_only: if True, keep only the 40 state indicators plus the
        correlation-selected columns in `feats_selected` (file-level list).
    """
    self.mode = mode

    # Read the csv into a float array, dropping the header row and id column.
    with open(path, 'r') as fp:
        rows = list(csv.reader(fp))
        data = np.array(rows[1:])[:, 1:].astype(float)

    if target_only:
        # 40 one-hot state columns + the hand-selected feature columns
        # (alternatively, skip the analysis and use list(range(40)) + [57, 75])
        feats = list(range(40)) + feats_selected
    else:
        feats = list(range(93))

    if mode == 'test':
        # Testing data: 893 x 93 (40 states + day1 (18) + day2 (18) + day3 (17))
        self.data = torch.FloatTensor(data[:, feats])
    else:
        # Training data: 2700 x 94 (40 states + day1 (18) + day2 (18) + day3 (18));
        # the last column is the regression target.
        target = data[:, -1]
        data = data[:, feats]

        # The baseline split train/dev by row order, which can bias the data
        # distribution; split randomly instead (fixed seed for reproducibility).
        train_idx, dev_idx = train_test_split(
            list(range(data.shape[0])), test_size = 0.3, random_state = 0)
        if self.mode == 'train':
            indices = train_idx
        elif self.mode == 'dev':
            indices = dev_idx

        # Convert the selected rows into PyTorch tensors.
        self.data = torch.FloatTensor(data[indices])
        self.target = torch.FloatTensor(target[indices])

    # Normalise the non-state columns (everything from index 40 on).
    if self.mode == "train":
        # Training set: use its own mean / std.
        self.mu = self.data[:, 40:].mean(dim=0, keepdim=True)
        self.std = self.data[:, 40:].std(dim=0, keepdim=True)
    else:
        # Dev/test sets: reuse the statistics saved from the training set
        # (see the DataLoader setup for how they are passed in).
        self.mu = mu
        self.std = std

    self.data[:,40:] = (self.data[:, 40:] - self.mu) / self.std
    self.dim = self.data.shape[1]

    print('Finished reading the {} set of COVID19 Dataset ({} samples found, each dim = {})'
    .format(mode, len(self.data), self.dim))



def __getitem__(self, index):
# Returns one sample at a time
if self.mode in ['train', 'dev']:
# For training
return self.data[index], self.target[index]
else:
# For testing (no target)
return self.data[index]

def __len__(self):
# Returns the size of the dataset
return len(self.data)
def prep_dataloader(path, mode, batch_size, n_jobs=0, target_only=False, mu=None, std=None):
    """Build a COVID19Dataset and wrap it in a DataLoader.

    mu/std default to None for the 'train' split (which fits its own
    statistics); pass the training-set mu/std for 'dev'/'test'.
    Returns (dataloader, mu, std) so the fitted statistics can be reused.
    """
    dataset = COVID19Dataset(path, mu, std, mode=mode, target_only=target_only)
    if mode == 'train':
        # Export the freshly fitted statistics for the dev/test loaders.
        mu = dataset.mu
        std = dataset.std
    loader = DataLoader(
        dataset, batch_size,
        shuffle=(mode == 'train'),  # shuffle only while training
        drop_last=False,
        num_workers=n_jobs, pin_memory=True)
    return loader, mu, std
class NeuralNet(nn.Module):
    """A small fully-connected regression network: input_dim -> 68 -> 1."""

    def __init__(self, input_dim):
        super(NeuralNet, self).__init__()
        # One hidden layer of width 68; deeper stacks overfit easily here.
        self.net = nn.Sequential(
            nn.Linear(input_dim, 68),
            nn.ReLU(),
            nn.Linear(68, 1)
        )
        # Mean squared error, averaged over the batch.
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, x):
        """Map a (batch_size, input_dim) tensor to (batch_size,) predictions."""
        return self.net(x).squeeze(1)

    def cal_loss(self, pred, target):
        """RMSE loss (matches the Kaggle metric and trains more smoothly).

        `eps` keeps the sqrt differentiable at zero error; `l2_reg` is a
        hook for an optional L2 penalty, left at 0 because it did not
        help in practice.
        """
        eps = 1e-6
        l2_reg = 0
        # Optional L2 regularization (uncomment to enable):
        # alpha = 0.0001
        # for name, w in self.net.named_parameters():
        #     if 'weight' in name:
        #         l2_reg += alpha * torch.norm(w, p=2).to(device)
        return torch.sqrt(self.criterion(pred, target) + eps) + l2_reg
#lr_reg=0, 后面那段代码用的是均方根误差,均方根误差和kaggle评测指标一致,而且训练模型也更平稳
def train(tr_set, dv_set, model, config, device):
    """Train `model` on tr_set with early stopping on the dev-set loss.

    The best checkpoint (lowest dev loss) is written to
    config['save_path'].  Returns (best_dev_loss, loss_record), where
    loss_record holds per-step train losses and per-epoch dev losses.
    """
    n_epochs = config['n_epochs']  # upper bound on training epochs

    # Instantiate the optimizer named in the config with its hyper-params.
    optimizer = getattr(torch.optim, config['optimizer'])(
        model.parameters(), **config['optim_hparas'])

    min_mse = 1000.
    loss_record = {'train': [], 'dev': []}  # loss history for plotting
    early_stop_cnt = 0
    epoch = 0
    while epoch < n_epochs:
        model.train()  # training mode
        for x, y in tr_set:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            pred = model(x)
            mse_loss = model.cal_loss(pred, y)
            mse_loss.backward()
            optimizer.step()
            loss_record['train'].append(mse_loss.detach().cpu().item())

        # Evaluate on the dev set once per epoch; keep the best checkpoint.
        dev_mse = dev(dv_set, model, device)
        if dev_mse < min_mse:
            min_mse = dev_mse
            print('Saving model (epoch = {:4d}, loss = {:.4f})'
                  .format(epoch + 1, min_mse))
            torch.save(model.state_dict(), config['save_path'])
            early_stop_cnt = 0
        else:
            early_stop_cnt += 1

        epoch += 1
        loss_record['dev'].append(dev_mse)
        if early_stop_cnt > config['early_stop']:
            # No improvement for config['early_stop'] consecutive epochs.
            break

    print('Finished training after {} epochs'.format(epoch))
    return min_mse, loss_record
def dev(dv_set, model, device):
    """Return the per-sample average loss of `model` over the dev loader."""
    model.eval()  # evaluation mode
    total_loss = 0
    for x, y in dv_set:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():  # no gradients needed for evaluation
            pred = model(x)
            mse_loss = model.cal_loss(pred, y)
        # Weight each batch loss by its batch size before averaging.
        total_loss += mse_loss.detach().cpu().item() * len(x)
    return total_loss / len(dv_set.dataset)
def test(tt_set, model, device):
model.eval() # set model to evalutation mode
preds = []
for x in tt_set: # iterate through the dataloader
x = x.to(device) # move data to device (cpu/cuda)
with torch.no_grad(): # disable gradient calculation
pred = model(x) # forward pass (compute output)
preds.append(pred.detach().cpu()) # collect prediction
preds = torch.cat(preds, dim=0).numpy() # concatenate all predictions and convert to a numpy array
return preds
# --- Experiment setup: device, output directory, feature mode, hyper-params ---
device = get_device()                 # get the current available device ('cpu' or 'cuda')
os.makedirs('models', exist_ok=True) # The trained model will be saved to ./models/
target_only = True # TODO: Using 40 states & 2 tested_positive features

# TODO: How to tune these hyper-parameters to improve your model's performance?
config = {
'n_epochs': 3000, # maximum number of epochs
'batch_size': 270, # mini-batch size for dataloader
'optimizer': 'SGD', # optimization algorithm (optimizer in torch.optim)
'optim_hparas': { # hyper-parameters for the optimizer (depends on which optimizer you are using)
'lr': 0.005, # learning rate of SGD
'momentum': 0.5 # momentum for SGD
},
'early_stop': 200, # early stopping epochs (the number epochs since your model's last improvement)
'save_path': 'models/model_select.path' # your model will be saved here
}
# Build the loaders: the train loader fits mu/std; dev/test reuse them so
# every split is normalized with the training-set statistics.
tr_set, tr_mu, tr_std = prep_dataloader(tr_path, 'train', config['batch_size'], target_only=target_only)
dv_set, mu_none, std_none = prep_dataloader(tr_path, 'dev', config['batch_size'], target_only=target_only, mu=tr_mu, std=tr_std)
tt_set, mu_none, std_none = prep_dataloader(tt_path, 'test', config['batch_size'], target_only=target_only, mu=tr_mu, std=tr_std)
Finished reading the train set of COVID19 Dataset (1890 samples found, each dim = 54)
Finished reading the dev set of COVID19 Dataset (810 samples found, each dim = 54)
Finished reading the test set of COVID19 Dataset (893 samples found, each dim = 54)
model = NeuralNet(tr_set.dataset.dim).to(device)  # Construct model and move to device
model_loss, model_loss_record = train(tr_set, dv_set, model, config, device)  # run the training loop

Saving model (epoch =    1, loss = 17.9400)
Saving model (epoch =    2, loss = 17.7633)
Saving model (epoch =    3, loss = 17.5787)
Saving model (epoch =    4, loss = 17.3771)
……
Saving model (epoch =  581, loss = 0.9606)
Saving model (epoch =  594, loss = 0.9606)
Saving model (epoch =  598, loss = 0.9606)
Saving model (epoch =  599, loss = 0.9604)
Saving model (epoch =  600, loss = 0.9603)
Saving model (epoch =  621, loss = 0.9602)
Saving model (epoch =  706, loss = 0.9601)
Saving model (epoch =  741, loss = 0.9601)
Saving model (epoch =  781, loss = 0.9598)
Saving model (epoch =  786, loss = 0.9597)
Finished training after 987 epochs
plot_learning_curve(model_loss_record, title='deep model')

png

dev(dv_set, model, device)  #验证集损失 
0.9599974950154623
del model
# Rebuild the model and reload the best checkpoint saved during training.
model = NeuralNet(tr_set.dataset.dim).to(device)
ckpt = torch.load(config['save_path'], map_location='cpu') # Load your best model
model.load_state_dict(ckpt)
plot_pred(dv_set, model, device) # Show prediction on the validation set

png

def save_pred(preds, file):
    """Write predictions to `file` as a two-column (id, tested_positive) CSV."""
    print('Saving results to {}'.format(file))
    header = ['id', 'tested_positive']
    # One row per prediction, keyed by its position in `preds`.
    rows = [[idx, value] for idx, value in enumerate(preds)]
    with open(file, 'w') as fp:
        csv_out = csv.writer(fp)
        csv_out.writerow(header)
        csv_out.writerows(rows)
preds = test(tt_set, model, device) # predict COVID-19 cases with your model
save_pred(preds, 'commit.csv') # save predictions to commit.csv for submission
Saving results to commit.csv

提交结果: