TFTS: Better handling of exogenous features
authorAllen Lavoie <allenl@google.com>
Fri, 9 Feb 2018 18:13:05 +0000 (10:13 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Feb 2018 18:16:58 +0000 (10:16 -0800)
Adds (dummy) exogenous features to the LSTM model-building example, and adds some small methods needed to support that (fetching the shape of embedded exogenous features).

Also makes it more automatic to export a SavedModel with exogenous features (placeholder shapes will be inferred from the given FeatureColumns), which makes the LSTM example friendlier.

PiperOrigin-RevId: 185157085

tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv
tensorflow/contrib/timeseries/examples/lstm.py
tensorflow/contrib/timeseries/python/timeseries/estimators.py
tensorflow/contrib/timeseries/python/timeseries/model.py
tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py

index 02a60d1..b49a066 100644 (file)
-0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867
-1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303
-2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864
-3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426
-4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223
-5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842
-6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606
-7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347
-8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951
-9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228
-10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897
-11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634
-12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594
-13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394
-14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609
-15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449
-16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251
-17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382
-18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767
-19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713
-20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251
-21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811
-22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681
-23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735
-24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436
-25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899
-26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814
-27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727
-28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582
-29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555
-30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696
-31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548
-32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627
-33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104
-34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156
-35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459
-36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576
-37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584
-38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577
-39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467
-40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566
-41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909
-42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021
-43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831
-44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905
-45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271
-46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094
-47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554
-48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769
-49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606
-50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629
-51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199
-52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961
-53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122
-54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454
-55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301
-56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182
-57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365
-58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011
-59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449
-60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229
-61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259
-62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272
-63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989
-64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496
-65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376
-66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206
-67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502
-68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219
-69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125
-70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514
-71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166
-72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832
-73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913
-74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188
-75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388
-76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136
-77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766
-78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959
-79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083
-80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483
-81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656
-82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107
-83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991
-84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527
-85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649
-86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788
-87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289
-88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298
-89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873
-90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669
-91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462
-92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232
-93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225
-94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288
-95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086
-96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161
-97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227
-98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937
-99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724
+0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0.
+1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0.
+2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0.
+3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0.
+4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0.
+5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0.
+6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0.
+7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0.
+8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0.
+9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0.
+10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0.
+11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0.
+12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0.
+13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0.
+14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0.
+15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0.
+16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0.
+17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0.
+18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0.
+19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0.
+20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0.
+21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0.
+22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0.
+23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0.
+24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0.
+25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0.
+26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0.
+27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0.
+28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0.
+29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0.
+30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0.
+31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0.
+32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0.
+33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0.
+34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0.
+35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0.
+36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0.
+37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0.
+38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0.
+39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0.
+40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0.
+41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0.
+42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0.
+43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0.
+44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0.
+45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0.
+46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0.
+47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0.
+48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0.
+49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0.
+50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0.
+51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0.
+52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0.
+53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0.
+54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0.
+55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0.
+56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0.
+57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0.
+58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0.
+59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0.
+60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0.
+61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0.
+62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0.
+63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0.
+64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0.
+65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0.
+66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0.
+67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0.
+68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0.
+69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0.
+70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0.
+71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0.
+72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0.
+73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0.
+74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0.
+75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0.
+76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0.
+77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0.
+78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0.
+79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0.
+80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0.
+81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0.
+82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0.
+83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0.
+84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0.
+85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0.
+86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0.
+87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0.
+88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0.
+89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0.
+90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0.
+91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0.
+92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0.
+93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0.
+94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0.
+95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0.
+96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0.
+97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0.
+98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0.
+99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0.
index 630f4fc..f37cafc 100644 (file)
@@ -48,7 +48,8 @@ _DATA_FILE = path.join(_MODULE_PATH, "data/multivariate_periods.csv")
 class _LSTMModel(ts_model.SequentialTimeSeriesModel):
   """A time series model-building example using an RNNCell."""
 
-  def __init__(self, num_units, num_features, dtype=tf.float32):
+  def __init__(self, num_units, num_features, exogenous_feature_columns=None,
+               dtype=tf.float32):
     """Initialize/configure the model object.
 
     Note that we do not start graph building here. Rather, this object is a
@@ -58,6 +59,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
       num_units: The number of units in the model's LSTMCell.
       num_features: The dimensionality of the time series (features per
         timestep).
+      exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
+          objects representing features which are inputs to the model but are
+          not predicted by it. These must then be present for training,
+          evaluation, and prediction.
       dtype: The floating point data type to use.
     """
     super(_LSTMModel, self).__init__(
@@ -65,6 +70,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
         train_output_names=["mean"],
         predict_output_names=["mean"],
         num_features=num_features,
+        exogenous_feature_columns=exogenous_feature_columns,
         dtype=dtype)
     self._num_units = num_units
     # Filled in by initialize_graph()
@@ -104,6 +110,8 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
         tf.zeros([], dtype=tf.int64),
         # The previous observation or prediction.
         tf.zeros([self.num_features], dtype=self.dtype),
+        # The most recently seen exogenous features.
+        tf.zeros(self._get_exogenous_embedding_shape(), dtype=self.dtype),
         # The state of the RNNCell (batch dimension removed since this parent
         # class will broadcast).
         [tf.squeeze(state_element, axis=0)
@@ -131,7 +139,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
       loss (note that we could also return other measures of goodness of fit,
       although only "loss" will be optimized).
     """
-    state_from_time, prediction, lstm_state = state
+    state_from_time, prediction, exogenous, lstm_state = state
     with tf.control_dependencies(
         [tf.assert_equal(current_times, state_from_time)]):
       # Subtract the mean and divide by the variance of the series.  Slightly
@@ -143,16 +151,22 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
           (prediction - transformed_values) ** 2, axis=-1)
       # Keep track of the new observation in model state. It won't be run
       # through the LSTM until the next _imputation_step.
-      new_state_tuple = (current_times, transformed_values, lstm_state)
+      new_state_tuple = (current_times, transformed_values,
+                         exogenous, lstm_state)
     return (new_state_tuple, predictions)
 
   def _prediction_step(self, current_times, state):
     """Advance the RNN state using a previous observation or prediction."""
-    _, previous_observation_or_prediction, lstm_state = state
+    _, previous_observation_or_prediction, exogenous, lstm_state = state
+    # Update LSTM state based on the most recent exogenous and endogenous
+    # features.
+    inputs = tf.concat([previous_observation_or_prediction, exogenous],
+                       axis=-1)
     lstm_output, new_lstm_state = self._lstm_cell_run(
-        inputs=previous_observation_or_prediction, state=lstm_state)
+        inputs=inputs, state=lstm_state)
     next_prediction = self._predict_from_lstm_output(lstm_output)
-    new_state_tuple = (current_times, next_prediction, new_lstm_state)
+    new_state_tuple = (current_times, next_prediction,
+                       exogenous, new_lstm_state)
     return new_state_tuple, {"mean": self._scale_back_data(next_prediction)}
 
   def _imputation_step(self, current_times, state):
@@ -164,9 +178,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
 
   def _exogenous_input_step(
       self, current_times, current_exogenous_regressors, state):
-    """Update model state based on exogenous regressors."""
-    raise NotImplementedError(
-        "Exogenous inputs are not implemented for this example.")
+    """Save exogenous regressors in model state for use in _prediction_step."""
+    state_from_time, prediction, _, lstm_state = state
+    return (state_from_time, prediction,
+            current_exogenous_regressors, lstm_state)
 
 
 def train_and_predict(
@@ -174,24 +189,37 @@ def train_and_predict(
     export_directory=None):
   """Train and predict using a custom time series model."""
   # Construct an Estimator from our LSTM model.
+  exogenous_feature_columns = [
+      # Exogenous features are not part of the loss, but can inform
+      # predictions. In this example the features have no extra information, but
+      # are included as an API example.
+      tf.contrib.layers.real_valued_column(
+          "2d_exogenous_feature", dimension=2)]
   estimator = ts_estimators.TimeSeriesRegressor(
-      model=_LSTMModel(num_features=5, num_units=128),
+      model=_LSTMModel(num_features=5, num_units=128,
+                       exogenous_feature_columns=exogenous_feature_columns),
       optimizer=tf.train.AdamOptimizer(0.001), config=estimator_config,
       # Set state to be saved across windows.
       state_manager=state_management.ChainingStateManager())
   reader = tf.contrib.timeseries.CSVReader(
       csv_file_name,
       column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,)
-                    + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5))
+                    + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5
+                    + ("2d_exogenous_feature",) * 2))
   train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
       reader, batch_size=4, window_size=32)
   estimator.train(input_fn=train_input_fn, steps=training_steps)
   evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
   evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)
   # Predict starting after the evaluation
+  predict_exogenous_features = {
+      "2d_exogenous_feature": numpy.concatenate(
+          [numpy.ones([1, 100, 1]), numpy.zeros([1, 100, 1])],
+          axis=-1)}
   (predictions,) = tuple(estimator.predict(
       input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
-          evaluation, steps=100)))
+          evaluation, steps=100,
+          exogenous_features=predict_exogenous_features)))
   times = evaluation["times"][0]
   observed = evaluation["observed"][0, :, :]
   predicted_mean = numpy.squeeze(numpy.concatenate(
@@ -204,7 +232,6 @@ def train_and_predict(
   input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
   export_location = estimator.export_savedmodel(
       export_directory, input_receiver_fn)
-
   # Predict using the SavedModel
   with tf.Graph().as_default():
     with tf.Session() as session:
@@ -213,7 +240,8 @@ def train_and_predict(
       saved_model_output = (
           tf.contrib.timeseries.saved_model_utils.predict_continuation(
               continue_from=evaluation, signatures=signatures,
-              session=session, steps=100))
+              session=session, steps=100,
+              exogenous_features=predict_exogenous_features))
       # The exported model gives the same results as the Estimator.predict()
       # call above.
       numpy.testing.assert_allclose(
index 3738dfa..f8355f3 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.layers.python.layers import feature_column
+
 from tensorflow.contrib.timeseries.python.timeseries import ar_model
 from tensorflow.contrib.timeseries.python.timeseries import feature_keys
 from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
@@ -72,15 +74,14 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
   # tf.Example containing all features (times, values, any exogenous features)
   # and serialized model state (possibly also as a tf.Example).
   def build_raw_serving_input_receiver_fn(self,
-                                          exogenous_features=None,
                                           default_batch_size=None,
                                           default_series_length=None):
     """Build an input_receiver_fn for export_savedmodel which accepts arrays.
 
+    Automatically creates placeholders for exogenous `FeatureColumn`s passed to
+    the model.
+
     Args:
-      exogenous_features: A dictionary mapping feature keys to exogenous
-        features (either Numpy arrays or Tensors). Used to determine the shapes
-        of placeholders for these features.
       default_batch_size: If specified, must be a scalar integer. Sets the batch
         size in the static shape information of all feature Tensors, which means
         only this batch size will be accepted by the exported model. If None
@@ -94,9 +95,6 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
       An input_receiver_fn which may be passed to the Estimator's
       export_savedmodel.
     """
-    if exogenous_features is None:
-      exogenous_features = {}
-
     def _serving_input_receiver_fn():
       """A receiver function to be passed to export_savedmodel."""
       placeholders = {}
@@ -119,14 +117,22 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
                   dtype=self._model.dtype),
               shape=(default_batch_size, default_series_length,
                      self._model.num_features)))
-      for feature_key, feature_value in exogenous_features.items():
-        value_tensor = ops.convert_to_tensor(feature_value)
-        value_tensor.get_shape().with_rank_at_least(2)
-        feature_shape = value_tensor.get_shape().as_list()
-        feature_shape[0] = default_batch_size
-        feature_shape[1] = default_series_length
+      with ops.Graph().as_default():
+        # Default placeholders have only an unknown batch dimension. Make them
+        # in a separate graph, then splice in the series length to the shapes
+        # and re-create them in the outer graph.
+        exogenous_feature_shapes = {
+            key: (value.get_shape(), value.dtype) for key, value
+            in feature_column.make_place_holder_tensors_for_base_features(
+                self._model.exogenous_feature_columns).items()}
+      for feature_key, (batch_only_feature_shape, value_dtype) in (
+          exogenous_feature_shapes.items()):
+        batch_only_feature_shape = batch_only_feature_shape.with_rank_at_least(
+            1).as_list()
+        feature_shape = ([default_batch_size, default_series_length]
+                         + batch_only_feature_shape[1:])
         placeholders[feature_key] = array_ops.placeholder(
-            dtype=value_tensor.dtype, name=feature_key, shape=feature_shape)
+            dtype=value_dtype, name=feature_key, shape=feature_shape)
       # Models may not know the shape of their state without creating some
       # variables/ops. Avoid polluting the default graph by making a new one. We
       # use only static metadata from the returned Tensors.
index b32b5c5..bac7d1e 100644 (file)
@@ -22,6 +22,7 @@ import abc
 import collections
 
 from tensorflow.contrib import layers
+from tensorflow.contrib.layers import feature_column
 
 from tensorflow.contrib.timeseries.python.timeseries import math_utils
 from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures
@@ -83,6 +84,11 @@ class TimeSeriesModel(object):
     self._stats_means = None
     self._stats_sigmas = None
 
+  @property
+  def exogenous_feature_columns(self):
+    """`FeatureColumn` objects for features which are not predicted."""
+    return self._exogenous_feature_columns
+
   # TODO(allenl): Move more of the generic machinery for generating and
   # predicting into TimeSeriesModel, and possibly share it between generate()
   # and predict()
@@ -250,6 +256,23 @@ class TimeSeriesModel(object):
     """
     pass
 
+  def _get_exogenous_embedding_shape(self):
+    """Computes the shape of the vector returned by _process_exogenous_features.
+
+    Returns:
+      The shape as a list. Does not include a batch dimension.
+    """
+    if not self._exogenous_feature_columns:
+      return (0,)
+    with ops.Graph().as_default():
+      placeholder_features = (
+          feature_column.make_place_holder_tensors_for_base_features(
+              self._exogenous_feature_columns))
+      embedded = layers.input_from_feature_columns(
+          columns_to_tensors=placeholder_features,
+          feature_columns=self._exogenous_feature_columns)
+      return embedded.get_shape().as_list()[1:]
+
   def _process_exogenous_features(self, times, features):
     """Create a single vector from exogenous features.
 
index 5980fc5..1fb4a3c 100644 (file)
@@ -187,9 +187,7 @@ class StateSpaceEquivalenceTests(test.TestCase):
     estimator.train(combined_input_fn, steps=1)
     export_location = estimator.export_savedmodel(
         self.get_temp_dir(),
-        estimator.build_raw_serving_input_receiver_fn(
-            exogenous_features={
-                "exogenous": numpy.zeros((0, 0), dtype=numpy.float32)}))
+        estimator.build_raw_serving_input_receiver_fn())
     with ops.Graph().as_default() as graph:
       random_model.initialize_graph()
       with self.test_session(graph=graph) as session:
@@ -209,7 +207,7 @@ class StateSpaceEquivalenceTests(test.TestCase):
             features={
                 feature_keys.FilteringFeatures.TIMES: [1, 2],
                 feature_keys.FilteringFeatures.VALUES: [1., 2.],
-                "exogenous": [-1., -2.]})
+                "exogenous": [[-1.], [-2.]]})
         second_split_filtering = saved_model_utils.filter_continuation(
             continue_from=first_split_filtering,
             signatures=signatures,
@@ -217,7 +215,7 @@ class StateSpaceEquivalenceTests(test.TestCase):
             features={
                 feature_keys.FilteringFeatures.TIMES: [3, 4],
                 feature_keys.FilteringFeatures.VALUES: [3., 4.],
-                "exogenous": [-3., -4.]
+                "exogenous": [[-3.], [-4.]]
             })
         combined_filtering = saved_model_utils.filter_continuation(
             continue_from={
@@ -227,7 +225,7 @@ class StateSpaceEquivalenceTests(test.TestCase):
             features={
                 feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4],
                 feature_keys.FilteringFeatures.VALUES: [1., 2., 3., 4.],
-                "exogenous": [-1., -2., -3., -4.]
+                "exogenous": [[-1.], [-2.], [-3.], [-4.]]
             })
         split_predict = saved_model_utils.predict_continuation(
             continue_from=second_split_filtering,
@@ -235,14 +233,14 @@ class StateSpaceEquivalenceTests(test.TestCase):
             session=session,
             steps=1,
             exogenous_features={
-                "exogenous": [[-5.]]})
+                "exogenous": [[[-5.]]]})
         combined_predict = saved_model_utils.predict_continuation(
             continue_from=combined_filtering,
             signatures=signatures,
             session=session,
             steps=1,
             exogenous_features={
-                "exogenous": [[-5.]]})
+                "exogenous": [[[-5.]]]})
     for state_key, combined_state_value in combined_filtering.items():
       if state_key == feature_keys.FilteringResults.TIMES:
         continue