The first steps of the new Gemm implementation are done: it is now a standard NN operator with proper shape inference. Interpreter (and maybe soft backend) support will be added in the next patch.
Pool and Conv operators use similar input data, which is why their attribute handling was partially merged. As a result they now produce the proper output shape.
Signed-off-by: Andrew V. Tischenko a.tischenko@partner.samsung.com
"modelIR/operations/DepthwiseConv2DOp.cpp"
"modelIR/operations/FullyConnectedOp.cpp"
"modelIR/operations/GatherOp.cpp"
+ "modelIR/operations/GemmOp.cpp"
"modelIR/operations/PadOp.cpp"
"modelIR/operations/PoolOp.cpp"
"modelIR/operations/SqueezeOp.cpp"
#include "core/modelIR/operations/EluOp.h"
#include "core/modelIR/operations/FullyConnectedOp.h"
#include "core/modelIR/operations/GatherOp.h"
+#include "core/modelIR/operations/GemmOp.h"
#include "core/modelIR/operations/PadOp.h"
#include "core/modelIR/operations/PoolOp.h"
#include "core/modelIR/operations/ReduceFOp.h"
dotBuilder.updateWithOp(&op, nodeInfo);
}
+void IrDotDumper::visit(ops::GemmOp& op) {
+  // Render the Gemm node with its input/output shapes in the dot dump.
+  dotBuilder.updateWithOp(&op, DotIrNodeInfo().withType("Gemm", op.getName())
+                                              .withInShapes(getInputShapes(op))
+                                              .withOutShapes(getOutputShapes(op)));
+}
+
void IrDotDumper::visit(ops::SoftmaxOp& op) {
auto nodeInfo = DotIrNodeInfo().withType("Softmax", op.getName())
.withInShapes(getInputShapes(op))
#include "core/modelIR/operations/EluOp.h"
#include "core/modelIR/operations/FullyConnectedOp.h"
#include "core/modelIR/operations/GatherOp.h"
+#include "core/modelIR/operations/GemmOp.h"
#include "core/modelIR/operations/PadOp.h"
#include "core/modelIR/operations/PoolOp.h"
#include "core/modelIR/operations/ReduceFOp.h"
--- /dev/null
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/GemmOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+void GemmOp::inferOutputShapes() {
+  // ONNX Gemm semantics: Y = alpha * A' * B' + beta * C, where
+  //   A': (M, K) -- A is (M, K) if transA == 0, (K, M) otherwise;
+  //   B': (K, N) -- B is (K, N) if transB == 0, (N, K) otherwise;
+  //   C : unidirectionally broadcastable to (M, N);
+  //   Y : (M, N).
+
+  // Flatten input A to 2-D along dim(0): (d0, d1 * ... * dn).
+  mir::Shape shape0 = {getInputShape(0).dim(0),
+                       getInputShape(0).numElements() / getInputShape(0).dim(0)};
+  assert(shape0.rank() == 2);
+  Shape shape_a(2);
+  if (_transA) {
+    shape_a.dim(0) = shape0.dim(1);
+    shape_a.dim(1) = shape0.dim(0);
+  } else {
+    shape_a.dim(0) = shape0.dim(0);
+    shape_a.dim(1) = shape0.dim(1);
+  }
+
+  auto& shape1 = getInputShape(1);
+  // B must be a matrix.
+  assert(shape1.rank() == 2);
+  Shape shape_b(2);
+  if (_transB) {
+    shape_b.dim(0) = shape1.dim(1);
+    shape_b.dim(1) = shape1.dim(0);
+  } else {
+    shape_b.dim(0) = shape1.dim(0);
+    shape_b.dim(1) = shape1.dim(1);
+  }
+
+  // Inner dimensions must agree: cols(A') == rows(B').
+  assert(shape_a.dim(1) == shape_b.dim(0));
+  Shape mult_a_b({shape_a.dim(0), shape_b.dim(1)});
+
+  const Shape& shape_c = getInputShape(2);
+  if (shape_c.rank() == 1) {
+    // FIX: a 1-D C of length N (or 1) broadcasts over all M rows; the old
+    // check wrongly required M == 1 for any 1-D C.
+    assert(shape_c.dim(0) == mult_a_b.dim(1) || shape_c.dim(0) == 1);
+  } else {
+    assert(shape_c.rank() == 2);
+    // Each dimension of C must match the product's dimension or be 1.
+    assert(shape_c.dim(0) == mult_a_b.dim(0) || shape_c.dim(0) == 1);
+    assert(shape_c.dim(1) == mult_a_b.dim(1) || shape_c.dim(1) == 1);
+  }
+
+  setOutputShape(0, mult_a_b);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
#define _NNC_BACKEND_INTERPRETER_CORE_DOTDUMPER_
#include "core/modelIR/Visitor.h"
-
#include "core/modelIR/ir_dot_builder.h"
namespace nnc
void visit(ops::EluOp& op) override;
void visit(ops::FullyConnectedOp& op) override;
void visit(ops::GatherOp& op) override;
+ void visit(ops::GemmOp& op) override;
void visit(ops::PadOp& op) override;
void visit(ops::PoolOp& op) override;
void visit(ops::ReduceFOp& op) override;
class ConstantOp : public Operation {
public:
-  ConstantOp(const TensorVariant& value) : Operation(Type::constant, {}), _value(value) {
-    setOutputShape(0, _value.getShape());
+  // Copies the shared_ptr so the op co-owns the tensor for its whole lifetime.
+  ConstantOp(const std::shared_ptr<mir::TensorVariant>& value) :
+    Operation(Type::constant, {}), _value(value) {
+    setOutputShape(0, _value->getShape());
   }
-  const TensorVariant& getValue() const { return _value; }
+  const std::shared_ptr<mir::TensorVariant>& getValue() const { return _value; }
private:
-  TensorVariant _value;
+  // FIX: stored by value, not as a reference member -- a reference to the
+  // caller's shared_ptr would dangle as soon as that pointer goes out of scope.
+  std::shared_ptr<mir::TensorVariant> _value;
};
} // namespace ops
--- /dev/null
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_IR_MODEL_GEMM_OP_H_
+#define _NNC_CORE_IR_MODEL_GEMM_OP_H_
+
+#include "core/modelIR/Operation.h"
+#include "core/modelIR/TensorVariant.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+/// ONNX-style Gemm: Y = alpha * op(A) * op(B) + beta * C,
+/// where op(X) is X or X^T depending on the transpose flags.
+class GemmOp : public Operation {
+public:
+  GemmOp(const IODescriptor a, const IODescriptor b, const IODescriptor c,
+         bool transA, bool transB, float alpha, float beta) :
+    Operation(Type::gemmOp, {a, b, c}),
+    _transA(transA), _transB(transB), _alpha(alpha), _beta(beta) {
+    inferOutputShapes();
+  }
+
+  // Accessors are const: they do not mutate the op.
+  bool getTransA() const { return _transA; }
+  bool getTransB() const { return _transB; }
+  float getAlpha() const { return _alpha; }
+  float getBeta() const { return _beta; }
+
+private:
+  /// Validates input ranks/dims and sets the (M, N) output shape.
+  void inferOutputShapes();
+
+  // Note: the a/b/c descriptors are already stored by the Operation base
+  // (passed to its constructor), so they are not duplicated here.
+  bool _transA;
+  bool _transB;
+  float _alpha;
+  float _beta;
+};
+} // namespace ops
+} // namespace mir
+} // namespace nnc
+
+#endif //_NNC_CORE_IR_MODEL_GEMM_OP_H_
HANDLE_OP(softmax, SoftmaxOp)
HANDLE_OP(pool, PoolOp)
HANDLE_OP(fullyConnected, FullyConnectedOp)
+HANDLE_OP(gemmOp, GemmOp)
HANDLE_OP(cappedReLU, CappedReluOp)
HANDLE_OP(biasAdd, BiasAddOp)
HANDLE_OP(variable, VariableOp)
void visit(ops::EluOp& op) override;
void visit(ops::FullyConnectedOp& op) override;
void visit(ops::GatherOp& op) override;
+ void visit(ops::GemmOp& op) override;
void visit(ops::PadOp& op) override;
void visit(ops::PoolOp& op) override;
void visit(ops::ReduceFOp& op) override;
#include "core/modelIR/operations/DropoutOp.h"
#include "core/modelIR/operations/ElementwiseOp.h"
#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
#include "core/modelIR/operations/PoolOp.h"
#include "core/modelIR/operations/ReduceFOp.h"
#include "core/modelIR/operations/ReluOp.h"
runLayer(layer);
}
+void AclCppOpGenerator::visit(ops::GemmOp& op) {
+  // Gemm is not implemented for the ACL backend yet.
+  (void)op;  // silence unused-parameter warning in NDEBUG builds
+  assert(false && "Gemm is not supported by the ACL backend yet");
+}
+
void AclCppOpGenerator::visit(ops::CappedReluOp& op) {
genActivation(op, "LU_BOUNDED_RELU", op.getCap());
}
void visit(mir::ops::EluOp& op) override;
void visit(mir::ops::FullyConnectedOp& op) override;
void visit(mir::ops::GatherOp& op) override;
+ void visit(mir::ops::GemmOp& op) override;
void visit(mir::ops::PadOp& op) override;
void visit(mir::ops::PoolOp& op) override;
void visit(mir::ops::ReduceFOp& op) override;
#include "passes/interpreter/Interpreter.h"
#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
#include "core/modelIR/operations/SoftmaxOp.h"
#include "core/modelIR/operations/CappedReluOp.h"
#include "core/modelIR/operations/DepthwiseConv2DOp.h"
#include "ops/DeConv2D.h"
#include "ops/Depthwise_conv_2D.h"
#include "ops/FullyConnected.h"
+#include "ops/Gemm.h"
#include "ops/Pool.h"
#include "ops/Reshape.h"
#include "ops/Softmax.h"
}
void NNInterpreter::visit(ops::ConstantOp& op) {
-  var(op.getId()) = {op.getValue()};
+  // getValue() now returns a shared_ptr<TensorVariant>; dereference it to
+  // store a copy of the tensor as this op's result.
+  var(op.getId()) = {*op.getValue()};
}
std::vector<TensorVariant> &NNInterpreter::getResult(Operation* op) {
var(op.getId()) = FullyConnected<float>(input, op)();
}
+void NNInterpreter::visit(ops::GemmOp& op) {
+  mapByName(&op);
+  // Fetch the tensor produced by the single (flattened) input of the Gemm op.
+  const auto descr = op.getPrevNodes()[0];
+  const TensorVariant input = var(descr.op->getId())[descr.index];
+  var(op.getId()) = Gemm<float>(input, op)();
+}
+
void NNInterpreter::visit(ops::CappedReluOp& op) {
mapByName(&op);
auto operand = op.getPrevNodes()[0];
--- /dev/null
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Gemm.h"
+//Do not remove
+//Used to force compile Gemm.h
--- /dev/null
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_GEMM_
+#define _NNC_CORE_BACKEND_INTERPRETER_GEMM_
+
+#include "core/modelIR/ShapeRange.h"
+#include "core/modelIR/operations/GemmOp.h"
+#include "OperationImpl.h"
+
+namespace nnc
+{
+template<typename T>
+class Gemm : public OperationImpl<T>
+{
+public:
+  // Interpreter-side Gemm kernel.
+  // NOTE(review): stub -- operator() only allocates the output tensor and
+  // performs no GEMM math yet; the real computation lands in a later patch.
+  Gemm(const mir::TensorVariant& input, const mir::ops::GemmOp& op)
+      : _op(op), _input(input) {}
+
+  std::vector<mir::TensorVariant> operator()() override {
+    // Allocate a tensor of the pre-inferred output shape and return it as-is.
+    auto res = OperationImpl<T>::allocate_tensor(_op.getOutputShape(0));
+    return {res};
+  }
+
+private:
+  const mir::ops::GemmOp& _op;   // attributes and inferred shapes
+  const mir::Tensor<T> _input;   // input tensor A
+};
+} // namespace nnc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_GEMM_
mir::Shape input_shape = ShapeHelper::createShape(onnx_tensor->dims(),
static_cast<size_t>(onnx_tensor->dims_size()));
_inputTensors[name] = createTensor(onnx_tensor, input_shape);
- auto constant = _graph->create<mir::ops::ConstantOp>(name, *_inputTensors[name].get());
+ auto constant = _graph->create<mir::ops::ConstantOp>(name, _inputTensors[name]);
_tensorNameToPrevMirOp[name] = constant;
constants.insert(constant);
} else {
// We're dealing with graph input
auto onnx_input_shape = input.type().tensor_type().shape();
- std::vector<int> shape_vector(onnx_input_shape.dim_size());
+ mir::Shape shape(4);
for (int i = 0; i < onnx_input_shape.dim_size(); i++) {
assert(onnx_input_shape.dim(i).has_dim_value());
- shape_vector[i] = onnx_input_shape.dim(i).dim_value();
+ shape.dim(i) = onnx_input_shape.dim(i).dim_value();
}
- mir::Shape shape(shape_vector);
- // TODO For now we only support convolutional networks. The input data have already been
- // transformed from ONNX NCHW format to ModelIR NHWC; reflect the changes in the IR.
- if (shape.rank() == 4)
- shape = mir::Shape{shape.dim(0), shape.dim(1), shape.dim(2), shape.dim(0)};
-
// TODO: Temporary solution!
auto node = _graph->create<mir::ops::VariableOp>(name, shape);
_tensorNameToPrevMirOp[name] = node;
_graph->setConstants(constants);
}
+// Prints a shape as "{d0, d1, ...} " to stdout; debug aid for the importer.
+static void dumpShape(const mir::Shape& shape) {
+  std::cout << "{";
+  for (int i = 0; i < shape.rank(); i++) {
+    std::cout << shape.dim(i) << (i == shape.rank() - 1 ? "} " : ", ");
+  }
+  // FIX: a rank-0 shape used to print an unmatched "{"; close it explicitly.
+  if (shape.rank() == 0)
+    std::cout << "} ";
+}
+
+// Dumps a human-readable summary (shapes, kernels, paddings) of the MIR
+// operations created for a single ONNX node. Debug aid only.
+void ONNXImporterImpl::dump(const std::vector<mir::Operation*>& ops, const onnx::NodeProto& onnx_node) {
+  for (auto op : ops) {
+    std::cout << onnx_node.op_type() << " '" << op->getName() << "' Input Shapes: ";
+    for (int i = 0; i < op->getNumInputs() ; i++) {
+      dumpShape(op->getInputShape(i));
+    }
+    std::cout << " Output Shapes: ";
+    for (int i = 0; i < op->getNumOutputs() ; i++) {
+      dumpShape(op->getOutputShape(i));
+    }
+    auto* onnx_op_type = ONNXPerfectHash::getONNXOpType(onnx_node.op_type().c_str(),
+                                                        onnx_node.op_type().size());
+    switch (onnx_op_type->opCode) {
+      case ONNXOpCode::opConv: {
+        // FIX: one ONNX Conv may expand to several MIR ops (Conv2D plus
+        // BiasAdd, see convertConv2D), so the downcast can return nullptr;
+        // guard before dereferencing instead of crashing.
+        auto* conv = dynamic_cast<mir::ops::Conv2DOp*>(op);
+        if (conv == nullptr)
+          break;
+        std::cout << " Weights tensor shape ";
+        dumpShape(conv->getKernel().getShape());
+        std::cout << " Strides ";
+        dumpShape(conv->getStrides());
+        std::cout << " Padding before: (" << conv->getPaddingBefore()[0] << " "
+                  << conv->getPaddingBefore()[1] << ")";
+        std::cout << " After: (" << conv->getPaddingAfter()[0] << " "
+                  << conv->getPaddingAfter()[1] << ")";
+        break;
+      }
+      case ONNXOpCode::opGlobalAveragePool:
+      case ONNXOpCode::opAveragePool:
+      case ONNXOpCode::opMaxPool: {
+        // Same guard as above: skip ops that are not the PoolOp itself.
+        auto* pool = dynamic_cast<mir::ops::PoolOp*>(op);
+        if (pool == nullptr)
+          break;
+        std::cout << " Kernel ";
+        dumpShape(pool->getWindowShape());
+        std::cout << " Strides ";
+        dumpShape(pool->getStrides());
+        std::cout << " Padding before: " << pool->getPaddingBefore()[0] << " "
+                  << pool->getPaddingBefore()[1];
+        std::cout << " After: " << pool->getPaddingAfter()[0] << " "
+                  << pool->getPaddingAfter()[1];
+        break;
+      }
+      default:
+        break;
+    }
+    std::cout << "\n";
+  }
+}
+
mir::Graph *ONNXImporterImpl::createIR() {
GOOGLE_PROTOBUF_VERIFY_VERSION;
assert (outputs.size());
// FIXME: it should be done properly via the given graph outputs
_graphOutputs.assign(outputs.begin(), outputs.end());
+ dump(outputs, onnx_node);
}
// set graph outputs
// TODO: it should be done with onnx graph outputs
void import() {};
mir::Graph *createIR() override;
+ void dump(const std::vector<mir::Operation*>& op, const onnx::NodeProto& onnx_node);
private:
void createGraphInputs();
#include "core/modelIR/operations/DepthwiseConv2DOp.h"
#include "core/modelIR/operations/DropoutOp.h"
#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
#include "core/modelIR/operations/PoolOp.h"
#include "core/modelIR/operations/ReluOp.h"
#include "core/modelIR/operations/ReshapeOp.h"
std::shared_ptr<char> shared_buffer (new char[buffer_size], std::default_delete<char[]>());
memcpy(shared_buffer.get(), src_data, buffer_size);
Shape tensor_shape = Shape({1});
+ // FIXME: it has to be shared_ptr
auto mir_tensor = new mir::TensorVariant(tensor_shape, shared_buffer, type, element_size);
return mir_tensor;
}
+// Kernel/stride/padding attributes shared by ONNX Conv and Pool nodes.
+struct KernelStridesPadding {
+  Shape kernel_shape;   // from the "kernel_shape" attribute
+  Shape strides_shape;  // from the "strides" attribute
+  std::vector<int32_t> padding_before{0, 0};  // leading (top/left) pads
+  std::vector<int32_t> padding_after{0, 0};   // trailing (bottom/right) pads
+};
+
+// Reads the mandatory "kernel_shape"/"strides" and the optional "pads"
+// attributes of a Conv/Pool ONNX node into `cdata`.
+static void getKernelStridesPadding(const onnx::NodeProto &onnx_node, KernelStridesPadding &cdata) {
+  auto* kshape = findAttribute(onnx_node, "kernel_shape");
+  assert(kshape && kshape->ints_size());
+  auto* strides = findAttribute(onnx_node, "strides");
+  assert(strides && strides->ints_size());
+  auto* pads = findAttribute(onnx_node, "pads");
+
+  cdata.kernel_shape = ShapeHelper::createShape(kshape->ints(), kshape->ints_size());
+  cdata.strides_shape = ShapeHelper::createShape(strides->ints(), strides->ints_size());
+
+  if (pads) {
+    assert(pads->ints_size() >= 2);
+    cdata.padding_before[0] = pads->ints(0);
+    cdata.padding_before[1] = pads->ints(1);
+    // TODO: ONNX padding could be for the beginning and ending along each axis
+    // that's why we should select the interesting ones.
+    if (pads->ints_size() == 4) {
+      cdata.padding_after[0] = pads->ints(2);
+      cdata.padding_after[1] = pads->ints(3);
+    }
+  }
+}  // FIX: removed the stray ';' that followed the function body
+
std::vector<Operation*> ONNXOpCreator::convertConv2D(InputOps& inputs,
                                                      const onnx::NodeProto& onnx_node) {
+  // Converts an ONNX Conv node into MIR Conv2D (plus an optional BiasAdd).
+  // inputs[0] is the data, inputs[1] the weights, inputs[2] the optional bias.
  assert(inputs.size() >= 2);
-  auto* strides = findAttribute(onnx_node, "strides");
-  assert(strides && strides->ints_size());
-  Shape onnx_strides_shape = ShapeHelper::createShape(strides->ints(), strides->ints_size());
-  // FIXME: it's a hack
-  Shape strides_shape = {onnx_strides_shape.dim(0), onnx_strides_shape.dim(1), 1};
+  KernelStridesPadding cdata;
+  getKernelStridesPadding(onnx_node, cdata);
+  // FIXME: It can be non-constant value.
  auto* in_weights = dynamic_cast<mir::ops::ConstantOp*>(inputs[1]);
-  assert(in_weights);
+  assert(in_weights && "Weights could be a constant tensor only");
  auto in_weights_tensor = in_weights->getValue();
-  // TODO: we don't support padding at the moment
-  std::vector<int32_t> padding_before{0, 0};
-  std::vector<int32_t> padding_after{0, 0};
-  Operation* input_bias;
-  if (inputs.size() > 2)
-    input_bias = inputs[2];
+  // We should transpose ONNX MCHW to HWOI
+  auto transposed = transposeTensor<2, 3, 1, 0>(in_weights_tensor);
+  // NOTE(review): verify the <2, 3, 1, 0> index order really yields HWOI from
+  // the ONNX weight layout -- TODO confirm against transposeTensor's contract.
+
+  mir::ops::ConstantOp* input_bias = nullptr;
+  if (inputs.size() > 2) {
+    input_bias = dynamic_cast<mir::ops::ConstantOp*>(inputs[2]);
+    assert(input_bias && "1D optional bias could be a constant tensor only");
+  }
+  // Keep only the data input; weights and bias have been consumed above.
  inputs.resize(1);
  std::vector<Operation*> outputs;
-  outputs = createOp<ops::Conv2DOp>(inputs[0]->getOutput(0), in_weights_tensor, strides_shape,
-                                    padding_before, padding_after);
-  // TODO: there could be bias tensor as inputs[2].
+  outputs = createOp<ops::Conv2DOp>(inputs[0]->getOutput(0), *transposed, cdata.strides_shape,
+                                    cdata.padding_before, cdata.padding_after);
+  // Chain a BiasAdd after the convolution when a bias tensor was provided.
+  if (input_bias)
+    outputs = createOp<ops::BiasAddOp>(outputs[0]->getOutput(0), *input_bias->getValue());
+
  return outputs;
}
ops::PoolOp::PoolingType pool_type;
std::vector<Operation*> result;
- std::vector<int32_t> padding_before{0, 0};
- std::vector<int32_t> padding_after{0, 0};
+ KernelStridesPadding cdata;
switch (op_code) {
case ONNXOpCode::opGlobalAveragePool:
ops::PoolOp::PoolingType::AVG,
inputs[0]->getOutputShape(0), // kernel_shape
Shape({1, 1}), // strides_shape
- padding_before, padding_after,// no padding
+ cdata.padding_before, cdata.padding_after,
ops::PoolOp::BorderType::ZEROFILLED,
ops::PoolOp::RoundMode::floor);
case ONNXOpCode::opAveragePool:
assert(false);
}
// Proceed with Average or Max Pool
- auto* kshape = findAttribute(onnx_node, "kernel_shape");
- assert(kshape && kshape->ints_size());
- auto* strides = findAttribute(onnx_node, "strides");
- assert(strides && strides->ints_size());
- auto* pads = findAttribute(onnx_node, "pads");
-
- Shape kernel_shape = ShapeHelper::createShape(kshape->ints(), kshape->ints_size());
- Shape strides_shape = ShapeHelper::createShape(strides->ints(), strides->ints_size());
-
- if (pads) {
- assert(pads->ints_size() >= 2);
- padding_before[0] = pads->ints(0);
- padding_before[1] = pads->ints(1);
- // TODO: ONNX padding could be for the beginning and ending along each axis that's why we
- // should select the interesting ones.
- if (pads->ints_size() == 4) {
- padding_after[0] = pads->ints(2);
- padding_after[1] = pads->ints(3);
- }
+ getKernelStridesPadding(onnx_node, cdata);
- }
- result = createOp<ops::PoolOp>(inputs[0]->getOutput(0), pool_type, kernel_shape, strides_shape,
- padding_before, padding_after, border_type,
+ result = createOp<ops::PoolOp>(inputs[0]->getOutput(0), pool_type,
+ cdata.kernel_shape, cdata.strides_shape,
+ cdata.padding_before, cdata.padding_after,
+ border_type,
ops::PoolOp::RoundMode::floor);
return result;
}
}
std::vector<Operation*> ONNXOpCreator::convertDropout(InputOps& inputs,
- const onnx::NodeProto& onnx_node) {
+ const onnx::NodeProto& onnx_node) {
bool found;
float value;
std::tie(found, value) = getFloatAttribute(onnx_node, "ratio");
}
std::vector<Operation*> ONNXOpCreator::convertGemm(InputOps& inputs,
-                                                      const onnx::NodeProto& onnx_node) {
-  // TODO: NIY
-  return inputs;
+                                                   const onnx::NodeProto& onnx_node) {
+  // Converts an ONNX Gemm node: Y = alpha * A' * B' + beta * C.
+  bool found;
+  int ivalue;
+  float fvalue;
+
+  // ONNX defaults: transA = transB = 0, alpha = beta = 1.0.
+  std::tie(found, ivalue) = getIntAttribute(onnx_node, "transA");
+  bool transA = found ? ivalue : 0;
+  std::tie(found, ivalue) = getIntAttribute(onnx_node, "transB");
+  bool transB = found ? ivalue : 0;
+  // FIX: "alpha"/"beta" are float attributes -- read them with
+  // getFloatAttribute (getIntAttribute would miss or truncate them).
+  std::tie(found, fvalue) = getFloatAttribute(onnx_node, "alpha");
+  float alpha = found ? fvalue : 1.0f;
+  std::tie(found, fvalue) = getFloatAttribute(onnx_node, "beta");
+  float beta = found ? fvalue : 1.0f;
+
+  // Flatten input A to 2-D along dim(0): (d0, d1 * ... * dn).
+  const auto& in_shape = inputs[0]->getOutputShape(0);
+  mir::Shape shape0({in_shape.dim(0), in_shape.numElements() / in_shape.dim(0)});
+  auto reshape = createOp<ops::ReshapeOp>(inputs[0]->getOutput(0), shape0);
+
+  // (The previously built `descriptors` vector was dead code -- the inputs
+  // are passed to GemmOp directly.)
+  return createOp<ops::GemmOp>(reshape[0]->getOutput(0), inputs[1]->getOutput(0),
+                               inputs[2]->getOutput(0), transA, transB, alpha, beta);
}
} // namespace nnc
#include "core/modelIR/operations/ElementwiseOp.h"
#include "core/modelIR/operations/EluOp.h"
#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
#include "core/modelIR/operations/GatherOp.h"
#include "core/modelIR/operations/PadOp.h"
#include "core/modelIR/operations/PoolOp.h"
addOpDescr(&op, "fullConnect");
}
+void ModelAnalyzer::visit(ops::GemmOp& op) {
+  // Map the Gemm node to the soft-backend kernel named "gemm".
+  addOpDescr(&op, "gemm");
+}
+
void ModelAnalyzer::visit(ops::CappedReluOp& op) {
addOpDescr(&op, "cappedRelu");
}
void visit(mir::ops::EluOp& op) override;
void visit(mir::ops::FullyConnectedOp& op) override;
void visit(mir::ops::GatherOp& op) override;
+ void visit(mir::ops::GemmOp& op) override;
void visit(mir::ops::PadOp& op) override;
void visit(mir::ops::PoolOp& op) override;
void visit(mir::ops::ReduceFOp& op) override;
#include "core/modelIR/operations/SoftmaxOp.h"
#include "core/modelIR/operations/PoolOp.h"
#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
#include "core/modelIR/operations/CappedReluOp.h"
#include "core/modelIR/operations/BiasAddOp.h"
#include "core/modelIR/operations/ReluOp.h"
serializeShape(op.getOutputShape(0));
}
+void Serializer::visit(ops::GemmOp& op) {
+  // NOTE(review): only the parameter offset is recorded; alpha/beta and the
+  // transpose flags are not serialized yet -- confirm the soft-backend "gemm"
+  // kernel does not need them before shipping.
+  _curOp->_paramStartOffset = _buffer.size();
+}
+
void Serializer::visit(ops::CappedReluOp& op) {
_curOp->_paramStartOffset = _buffer.size();
serializeT<float>(op.getCap());
void Serializer::visit(ops::ConstantOp& op) {
  _curOp->_paramStartOffset = _buffer.size();
-  serializeTensor(op.getValue());
+  // getValue() now returns a shared_ptr; serialize the pointed-to tensor.
+  serializeTensor(*op.getValue());
}
void Serializer::visit(ops::ReluOp& op) {
void visit(mir::ops::EluOp& op) override;
void visit(mir::ops::FullyConnectedOp& op) override;
void visit(mir::ops::GatherOp& op) override;
+ void visit(mir::ops::GemmOp& op) override;
void visit(mir::ops::PadOp& op) override;
void visit(mir::ops::PoolOp& op) override;
void visit(mir::ops::ReduceFOp& op) override;
#include "core/modelIR/operations/VariableOp.h"
#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
#include "core/modelIR/operations/Conv2DOp.h"
#include "core/modelIR/operations/DepthwiseConv2DOp.h"
#include "core/modelIR/operations/PoolOp.h"