From 0d611134fa54d6f69467625fe8467c12e019892a Mon Sep 17 00:00:00 2001 From: Andrew Tulloch Date: Mon, 5 Aug 2019 09:31:19 -0700 Subject: [PATCH] Metal reinterpret fix (#3706) --- src/codegen/codegen_metal.cc | 17 +++++++++++++++-- src/codegen/codegen_metal.h | 3 +++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/codegen/codegen_metal.cc b/src/codegen/codegen_metal.cc index 118904f..acd131f 100644 --- a/src/codegen/codegen_metal.cc +++ b/src/codegen/codegen_metal.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -246,6 +246,19 @@ void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLI os << ')'; } +void CodeGenMetal::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*) + if (op->is_intrinsic(Call::reinterpret)) { + // generate as_type(ARG) + os << "(as_type<"; + this->PrintType(op->type, os); + os << ">("; + this->PrintExpr(op->args[0], os); + os << "))"; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + runtime::Module BuildMetal(Array funcs) { using tvm::runtime::Registry; bool output_ssa = false; diff --git a/src/codegen/codegen_metal.h b/src/codegen/codegen_metal.h index 5a94d8b..02d5451 100644 --- a/src/codegen/codegen_metal.h +++ b/src/codegen/codegen_metal.h @@ -53,6 +53,9 @@ class CodeGenMetal final : public CodeGenC { // overload visitor void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) + // overload visitor + void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*) + private: int thread_index_bits_{32}; }; -- 2.7.4