From 01b96c59f410b44a6279627529a643b1e4da4aa5 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 26 Feb 2018 14:00:07 -0800 Subject: [PATCH] TFTS: Switch to using core feature columns This fixes some shape issues that came up when using the tf.contrib.layers parsing functions. Adds a string -> embedding column API example to the LSTM example. PiperOrigin-RevId: 187076400 --- .../examples/data/multivariate_periods.csv | 200 ++++++++++----------- .../contrib/timeseries/examples/known_anomaly.py | 8 +- tensorflow/contrib/timeseries/examples/lstm.py | 26 ++- .../timeseries/python/timeseries/estimators.py | 53 +++--- .../contrib/timeseries/python/timeseries/model.py | 38 ++-- .../state_space_models/state_space_model.py | 10 +- 6 files changed, 177 insertions(+), 158 deletions(-) diff --git a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv index b49a066..9b15b4f 100644 --- a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv +++ b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv @@ -1,100 +1,100 @@ -0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0. -1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0. -2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0. -3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0. -4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0. -5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0. -6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0. -7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0. -8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0. -9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0. -10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0. -11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0. -12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0. -13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0. -14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0. -15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0. -16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0. -17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0. -18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0. -19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0. -20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0. -21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0. -22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0. -23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0. -24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0. -25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0. -26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0. -27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0. -28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0. -29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0. -30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0. -31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0. -32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0. -33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0. -34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0. -35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0. -36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0. -37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0. -38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0. -39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0. -40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0. -41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0. -42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0. -43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0. -44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0. -45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0. -46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0. -47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0. -48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0. -49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0. -50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0. -51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0. -52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0. -53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0. -54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0. -55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0. -56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0. -57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0. -58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0. -59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0. -60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0. -61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0. -62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0. -63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0. -64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0. -65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0. -66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0. -67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0. -68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0. -69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0. -70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0. -71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0. -72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0. -73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0. -74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0. -75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0. -76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0. -77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0. -78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0. -79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0. -80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0. -81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0. -82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0. -83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0. -84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0. -85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0. -86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0. -87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0. -88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0. -89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0. -90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0. -91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0. -92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0. -93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0. -94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0. -95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0. -96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0. -97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0. -98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0. -99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0. +0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0.,strkeya +1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0.,strkeyb +2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0.,strkey +3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0.,strkey +4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0.,strkey +5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0.,strkey +6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0.,strkey +7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0.,strkey +8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0.,strkey +9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0.,strkey +10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0.,strkeyc +11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0.,strkey +12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0.,strkey +13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0.,strkey +14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0.,strkey +15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0.,strkey +16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0.,strkey +17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0.,strkey +18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0.,strkey +19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0.,strkey +20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0.,strkey +21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0.,strkey +22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0.,strkey +23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0.,strkey +24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0.,strkey +25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0.,strkey +26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0.,strkey +27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0.,strkey +28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0.,strkey +29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0.,strkey +30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0.,strkey +31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0.,strkey +32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0.,strkey +33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0.,strkey +34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0.,strkey +35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0.,strkey +36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0.,strkey +37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0.,strkeyd +38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0.,strkey +39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0.,strkey +40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0.,strkey +41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0.,strkey +42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0.,strkey +43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0.,strkey +44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0.,strkey +45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0.,strkey +46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0.,strkey +47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0.,strkey +48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0.,strkey +49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0.,strkey +50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0.,strkey +51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0.,strkey +52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0.,strkey +53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0.,strkey +54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0.,strkey +55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0.,strkey +56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0.,strkey +57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0.,strkey +58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0.,strkey +59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0.,strkey +60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0.,strkey +61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0.,strkey +62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0.,strkey +63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0.,strkey +64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0.,strkey +65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0.,strkey +66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0.,strkey +67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0.,strkey +68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0.,strkey +69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0.,strkey +70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0.,strkey +71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0.,strkey +72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0.,strkey +73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0.,strkey +74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0.,strkey +75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0.,strkey +76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0.,strkey +77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0.,strkey +78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0.,strkey +79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0.,strkey +80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0.,strkey +81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0.,strkey +82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0.,strkey +83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0.,strkey +84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0.,strkey +85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0.,strkey +86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0.,strkey +87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0.,strkey +88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0.,strkey +89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0.,strkey +90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0.,strkey +91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0.,strkey +92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0.,strkey +93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0.,strkey +94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0.,strkey +95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0.,strkey +96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0.,strkey +97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0.,strkey +98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0.,strkey +99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0.,strkey diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py index 7659dd3..c08c0b0 100644 --- a/tensorflow/contrib/timeseries/examples/known_anomaly.py +++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py @@ -46,12 +46,12 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300): # Indicate the format of our exogenous feature, in this case a string # representing a boolean value. - string_feature = tf.contrib.layers.sparse_column_with_keys( - column_name="is_changepoint", keys=["no", "yes"]) + string_feature = tf.feature_column.categorical_column_with_vocabulary_list( + key="is_changepoint", vocabulary_list=["no", "yes"]) # Specify the way this feature is presented to the model, here using a one-hot # encoding. - one_hot_feature = tf.contrib.layers.one_hot_column( - sparse_id_column=string_feature) + one_hot_feature = tf.feature_column.indicator_column( + categorical_column=string_feature) estimator = tf.contrib.timeseries.StructuralEnsembleRegressor( periodicities=12, diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index f37cafc..2eee878 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -59,10 +59,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): num_units: The number of units in the model's LSTMCell. num_features: The dimensionality of the time series (features per timestep). - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects representing features which are inputs to the model but are - not predicted by it. These must then be present for training, - evaluation, and prediction. + exogenous_feature_columns: A list of `tf.feature_column`s representing + features which are inputs to the model but are not predicted by + it. These must then be present for training, evaluation, and + prediction. dtype: The floating point data type to use. """ super(_LSTMModel, self).__init__( @@ -189,12 +189,16 @@ def train_and_predict( export_directory=None): """Train and predict using a custom time series model.""" # Construct an Estimator from our LSTM model. + categorical_column = tf.feature_column.categorical_column_with_hash_bucket( + key="categorical_exogenous_feature", hash_bucket_size=16) exogenous_feature_columns = [ # Exogenous features are not part of the loss, but can inform # predictions. In this example the features have no extra information, but # are included as an API example. - tf.contrib.layers.real_valued_column( - "2d_exogenous_feature", dimension=2)] + tf.feature_column.numeric_column( + "2d_exogenous_feature", shape=(2,)), + tf.feature_column.embedding_column( + categorical_column=categorical_column, dimension=10)] estimator = ts_estimators.TimeSeriesRegressor( model=_LSTMModel(num_features=5, num_units=128, exogenous_feature_columns=exogenous_feature_columns), @@ -205,7 +209,11 @@ def train_and_predict( csv_file_name, column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,) + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5 - + ("2d_exogenous_feature",) * 2)) + + ("2d_exogenous_feature",) * 2 + + ("categorical_exogenous_feature",)), + # Data types other than for `times` need to be specified if they aren't + # float32. In this case one of our exogenous features has string dtype. + column_dtypes=((tf.int64,) + (tf.float32,) * 7 + (tf.string,))) train_input_fn = tf.contrib.timeseries.RandomWindowInputFn( reader, batch_size=4, window_size=32) estimator.train(input_fn=train_input_fn, steps=training_steps) @@ -215,7 +223,9 @@ def train_and_predict( predict_exogenous_features = { "2d_exogenous_feature": numpy.concatenate( [numpy.ones([1, 100, 1]), numpy.zeros([1, 100, 1])], - axis=-1)} + axis=-1), + "categorical_exogenous_feature": numpy.array( + ["strkey"] * 100)[None, :, None]} (predictions,) = tuple(estimator.predict( input_fn=tf.contrib.timeseries.predict_continuation_input_fn( evaluation, steps=100, diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index f8355f3..8d13343 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.layers.python.layers import feature_column - from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib @@ -31,10 +29,12 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filterin from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.export import export_lib +from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.training import training as train @@ -117,22 +117,29 @@ class TimeSeriesRegressor(estimator_lib.Estimator): dtype=self._model.dtype), shape=(default_batch_size, default_series_length, self._model.num_features))) - with ops.Graph().as_default(): - # Default placeholders have only an unknown batch dimension. Make them - # in a separate graph, then splice in the series length to the shapes - # and re-create them in the outer graph. - exogenous_feature_shapes = { - key: (value.get_shape(), value.dtype) for key, value - in feature_column.make_place_holder_tensors_for_base_features( - self._model.exogenous_feature_columns).items()} - for feature_key, (batch_only_feature_shape, value_dtype) in ( - exogenous_feature_shapes.items()): - batch_only_feature_shape = batch_only_feature_shape.with_rank_at_least( - 1).as_list() - feature_shape = ([default_batch_size, default_series_length] - + batch_only_feature_shape[1:]) - placeholders[feature_key] = array_ops.placeholder( - dtype=value_dtype, name=feature_key, shape=feature_shape) + if self._model.exogenous_feature_columns: + with ops.Graph().as_default(): + # Default placeholders have only an unknown batch dimension. Make them + # in a separate graph, then splice in the series length to the shapes + # and re-create them in the outer graph. + parsed_features = ( + feature_column.make_parse_example_spec( + self._model.exogenous_feature_columns)) + placeholder_features = parsing_ops.parse_example( + serialized=array_ops.placeholder( + shape=[None], dtype=dtypes.string), + features=parsed_features) + exogenous_feature_shapes = { + key: (value.get_shape(), value.dtype) for key, value + in placeholder_features.items()} + for feature_key, (batch_only_feature_shape, value_dtype) in ( + exogenous_feature_shapes.items()): + batch_only_feature_shape = ( + batch_only_feature_shape.with_rank_at_least(1).as_list()) + feature_shape = ([default_batch_size, default_series_length] + + batch_only_feature_shape[1:]) + placeholders[feature_key] = array_ops.placeholder( + dtype=value_dtype, name=feature_key, shape=feature_shape) # Models may not know the shape of their state without creating some # variables/ops. Avoid polluting the default graph by making a new one. We # use only static metadata from the returned Tensors. @@ -333,11 +340,11 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): determine the model size. Learning autoregressive coefficients typically requires more steps and a smaller step size than other components. - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects (for example tf.contrib.layers.embedding_column) corresponding - to exogenous features which provide extra information to the model but - are not part of the series to be predicted. Passed to - tf.contrib.layers.input_from_feature_columns. + exogenous_feature_columns: A list of `tf.feature_column`s (for example + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not part + of the series to be predicted. Passed to + `tf.feature_column.input_layer`. exogenous_update_condition: A function taking two Tensor arguments, `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]), and diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index bac7d1e..7644764 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -21,18 +21,17 @@ from __future__ import print_function import abc import collections -from tensorflow.contrib import layers -from tensorflow.contrib.layers import feature_column - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures +from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope @@ -66,11 +65,11 @@ class TimeSeriesModel(object): Args: num_features: Number of features for the time series - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects (for example tf.contrib.layers.embedding_column) corresponding - to exogenous features which provide extra information to the model but - are not part of the series to be predicted. Passed to - tf.contrib.layers.input_from_feature_columns. + exogenous_feature_columns: A list of `tf.feature_column`s (for example + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not + part of the series to be predicted. Passed to + `tf.feature_column.input_layer`. dtype: The floating point datatype to use. """ if exogenous_feature_columns: @@ -86,7 +85,7 @@ class TimeSeriesModel(object): @property def exogenous_feature_columns(self): - """`FeatureColumn` objects for features which are not predicted.""" + """`tf.feature_colum`s for features which are not predicted.""" return self._exogenous_feature_columns # TODO(allenl): Move more of the generic machinery for generating and @@ -265,11 +264,14 @@ class TimeSeriesModel(object): if not self._exogenous_feature_columns: return (0,) with ops.Graph().as_default(): - placeholder_features = ( - feature_column.make_place_holder_tensors_for_base_features( + parsed_features = ( + feature_column.make_parse_example_spec( self._exogenous_feature_columns)) - embedded = layers.input_from_feature_columns( - columns_to_tensors=placeholder_features, + placeholder_features = parsing_ops.parse_example( + serialized=array_ops.placeholder(shape=[None], dtype=dtypes.string), + features=parsed_features) + embedded = feature_column.input_layer( + features=placeholder_features, feature_columns=self._exogenous_feature_columns) return embedded.get_shape().as_list()[1:] @@ -308,13 +310,13 @@ class TimeSeriesModel(object): # Avoid shape warnings when embedding "scalar" exogenous features (those # with only batch and window dimensions); input_from_feature_columns # expects input ranks to match the embedded rank. - if tensor.get_shape().ndims == 1: + if tensor.get_shape().ndims == 1 and tensor.dtype != dtypes.string: exogenous_features_single_batch_dimension[name] = tensor[:, None] else: exogenous_features_single_batch_dimension[name] = tensor embedded_exogenous_features_single_batch_dimension = ( - layers.input_from_feature_columns( - columns_to_tensors=exogenous_features_single_batch_dimension, + feature_column.input_layer( + features=exogenous_features_single_batch_dimension, feature_columns=self._exogenous_feature_columns, trainable=True)) exogenous_regressors = array_ops.reshape( @@ -381,8 +383,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel): may use _scale_back_data or _scale_back_variance to return predictions to the input scale. dtype: The floating point datatype to use. - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects. See `TimeSeriesModel`. + exogenous_feature_columns: A list of `tf.feature_column`s objects. See + `TimeSeriesModel`. exogenous_update_condition: A function taking two Tensor arguments `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]) and returning a diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py index 6257002..951c654 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py @@ -112,11 +112,11 @@ class StateSpaceModelConfiguration( exogenous_noise_decreases: If True, exogenous regressors can "set" model state, decreasing uncertainty. If both this parameter and exogenous_noise_increases are False, exogenous regressors are ignored. - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects (for example tf.contrib.layers.embedding_column) corresponding - to exogenous features which provide extra information to the model but - are not part of the series to be predicted. Passed to - tf.contrib.layers.input_from_feature_columns. + exogenous_feature_columns: A list of `tf.feature_column`s (for example + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not part + of the series to be predicted. Passed to + `tf.feature_column.input_layer`. exogenous_update_condition: A function taking two Tensor arguments `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]) and returning a -- 2.7.4