gallivm: implement execution mask for scatter stores
authorBrian Paul <brianp@vmware.com>
Thu, 4 Nov 2010 16:01:28 +0000 (10:01 -0600)
committerBrian Paul <brianp@vmware.com>
Thu, 4 Nov 2010 16:01:28 +0000 (10:01 -0600)
src/gallium/auxiliary/gallivm/lp_bld_tgsi_soa.c

index e15baa3..a74cefd 100644 (file)
@@ -120,9 +120,12 @@ struct lp_build_tgsi_soa_context
 {
    struct lp_build_context base;
 
-   /* Builder for integer masks and indices */
+   /* Builder for vector integer masks and indices */
    struct lp_build_context uint_bld;
 
+   /* Builder for scalar elements of shader's data type (float) */
+   struct lp_build_context elem_bld;
+
    LLVMValueRef consts_ptr;
    const LLVMValueRef *pos;
    const LLVMValueRef (*inputs)[NUM_CHANNELS];
@@ -472,14 +475,26 @@ build_gather(struct lp_build_tgsi_soa_context *bld,
  * Scatter/store vector.
  */
 static void
-build_scatter(struct lp_build_tgsi_soa_context *bld,
-              LLVMValueRef base_ptr,
-              LLVMValueRef indexes,
-              LLVMValueRef values)
+emit_mask_scatter(struct lp_build_tgsi_soa_context *bld,
+                  LLVMValueRef base_ptr,
+                  LLVMValueRef indexes,
+                  LLVMValueRef values,
+                  struct lp_exec_mask *mask,
+                  LLVMValueRef pred)
 {
    LLVMBuilderRef builder = bld->base.builder;
    unsigned i;
 
+   /* Mix the predicate and execution mask */
+   if (mask->has_mask) {
+      if (pred) {
+         pred = LLVMBuildAnd(mask->bld->builder, pred, mask->exec_mask, "");
+      }
+      else {
+         pred = mask->exec_mask;
+      }
+   }
+
    /*
     * Loop over elements of index_vec, store scalar value.
     */
@@ -488,12 +503,22 @@ build_scatter(struct lp_build_tgsi_soa_context *bld,
       LLVMValueRef index = LLVMBuildExtractElement(builder, indexes, ii, "");
       LLVMValueRef scalar_ptr = LLVMBuildGEP(builder, base_ptr, &index, 1, "scatter_ptr");
       LLVMValueRef val = LLVMBuildExtractElement(builder, values, ii, "scatter_val");
+      LLVMValueRef scalar_pred = pred ?
+         LLVMBuildExtractElement(builder, pred, ii, "scatter_pred") : NULL;
 
       if (0)
          lp_build_printf(builder, "scatter %d: val %f at %d %p\n",
                          ii, val, index, scalar_ptr);
 
-      LLVMBuildStore(builder, val, scalar_ptr);
+      if (scalar_pred) {
+         LLVMValueRef real_val, dst_val;
+         dst_val = LLVMBuildLoad(builder, scalar_ptr, "");
+         real_val = lp_build_select(&bld->elem_bld, scalar_pred, val, dst_val);
+         LLVMBuildStore(builder, real_val, scalar_ptr);
+      }
+      else {
+         LLVMBuildStore(builder, val, scalar_ptr);
+      }
    }
 }
 
@@ -847,7 +872,8 @@ emit_store(
                                         float_ptr_type, "");
 
          /* Scatter store values into temp registers */
-         build_scatter(bld, temps_array, index_vec, value);
+         emit_mask_scatter(bld, temps_array, index_vec, value,
+                           &bld->exec_mask, pred);
       }
       else {
          LLVMValueRef temp_ptr = get_temp_ptr(bld, reg->Register.Index,
@@ -2192,6 +2218,7 @@ lp_build_tgsi_soa(LLVMBuilderRef builder,
    memset(&bld, 0, sizeof bld);
    lp_build_context_init(&bld.base, builder, type);
    lp_build_context_init(&bld.uint_bld, builder, lp_uint_type(type));
+   lp_build_context_init(&bld.elem_bld, builder, lp_elem_type(type));
    bld.mask = mask;
    bld.pos = pos;
    bld.inputs = inputs;