29 return CloneBase<GatherLayer>(graph,
GetName());
41 const unsigned int outputDim = paramsDim - 1 + indicesDim;
43 std::vector<unsigned int> dimSizes;
45 for (
unsigned int i = 0; i < indicesDim; ++i)
47 dimSizes.push_back(indices.
GetShape()[i]);
49 for (
unsigned int i = 1; i < paramsDim; ++i)
51 dimSizes.push_back(params.
GetShape()[i]);
56 ConditionalThrowIfNotEqual<LayerValidationException>(
57 "GatherLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
const TensorShape & GetShape() const
Copyright (c) 2020 ARM Limited.
GatherLayer * Clone(Graph &graph) const override
Creates a dynamically-allocated copy of this layer.
void VerifyLayerConnections(unsigned int expectedConnections, const CheckLocation &location) const
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
This layer represents a Gather operator.
WorkloadInfo PrepInfoAndDesc(QueueDescriptor &descriptor) const
Helper function to reduce duplication in *Layer::CreateWorkload.
GatherLayer(const char *name)
Constructor to create a GatherLayer.
void Accept(ILayerVisitor &visitor) const override
Apply a visitor to this layer.
virtual std::unique_ptr< IWorkload > CreateGather(const GatherQueueDescriptor &descriptor, const WorkloadInfo &info) const
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< float > &params, const int32_t *indices, Encoder< float > &output)
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
virtual const TensorInfo & GetTensorInfo() const =0
const char * GetName() const override
Returns the name of the layer.
virtual void VisitGatherLayer(const IConnectableLayer *layer, const char *name=nullptr)=0
Function that a Gather layer should call back to when its Accept(ILayerVisitor&) function is invoked...
void ValidateTensorShapesFromInputs() override
Check if the input tensor shape(s) will lead to a valid configuration of GatherLayer.
const TensorInfo & GetTensorInfo() const override
unsigned int GetNumDimensions() const
virtual std::unique_ptr< IWorkload > CreateWorkload(const IWorkloadFactory &factory) const override
Makes a workload for the Gather type.