Android: Add helper class for generating OpenGL shaders

This CL adds a helper class GlShaderBuilder to build an instances of
RendererCommon.GlDrawer that can accept multiple input sources
(OES, RGB, or YUV) using a generic fragment shader as input.

Bug: webrtc:9355
Change-Id: I14a0a280d2b6f838984f7b60897cc0c58e2a948a
Reviewed-on: https://webrtc-review.googlesource.com/80940
Commit-Queue: Magnus Jedvert <magjed@webrtc.org>
Reviewed-by: Sami Kalliomäki <sakal@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23622}
diff --git a/BUILD.gn b/BUILD.gn
index e997937..7566f2a 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -584,6 +584,7 @@
         "examples/androidjunit/src/org/appspot/apprtc/BluetoothManagerTest.java",
         "examples/androidjunit/src/org/appspot/apprtc/DirectRTCClientTest.java",
         "examples/androidjunit/src/org/appspot/apprtc/TCPChannelClientTest.java",
+        "sdk/android/tests/src/org/webrtc/GlGenericDrawerTest.java",
         "sdk/android/tests/src/org/webrtc/CameraEnumerationTest.java",
         "sdk/android/tests/src/org/webrtc/ScalingSettingsTest.java",
       ]
diff --git a/sdk/android/BUILD.gn b/sdk/android/BUILD.gn
index 1ec1d68..6d23594 100644
--- a/sdk/android/BUILD.gn
+++ b/sdk/android/BUILD.gn
@@ -757,6 +757,8 @@
     "api/org/webrtc/SurfaceTextureHelper.java",
     "api/org/webrtc/TextureBufferImpl.java",
     "api/org/webrtc/YuvConverter.java",
+    "api/org/webrtc/VideoFrameDrawer.java",
+    "src/java/org/webrtc/GlGenericDrawer.java",
     "api/org/webrtc/YuvHelper.java",
     "src/java/org/webrtc/EglBase10.java",
     "src/java/org/webrtc/EglBase14.java",
@@ -823,7 +825,6 @@
     "api/org/webrtc/GlRectDrawer.java",
     "api/org/webrtc/VideoDecoderFallback.java",
     "api/org/webrtc/VideoEncoderFallback.java",
-    "api/org/webrtc/VideoFrameDrawer.java",
     "src/java/org/webrtc/NativeCapturerObserver.java",
     "src/java/org/webrtc/NV21Buffer.java",
     "src/java/org/webrtc/VideoDecoderWrapper.java",
diff --git a/sdk/android/api/org/webrtc/GlRectDrawer.java b/sdk/android/api/org/webrtc/GlRectDrawer.java
index 44787f7..d1fbd1b 100644
--- a/sdk/android/api/org/webrtc/GlRectDrawer.java
+++ b/sdk/android/api/org/webrtc/GlRectDrawer.java
@@ -10,201 +10,22 @@
 
 package org.webrtc;
 
-import android.opengl.GLES11Ext;
-import android.opengl.GLES20;
-import java.nio.FloatBuffer;
-import java.util.IdentityHashMap;
-import java.util.Map;
-
-/**
- * Helper class to draw an opaque quad on the target viewport location. Rotation, mirror, and
- * cropping is specified using a 4x4 texture coordinate transform matrix. The frame input can either
- * be an OES texture or YUV textures in I420 format. The GL state must be preserved between draw
- * calls, this is intentional to maximize performance. The function release() must be called
- * manually to free the resources held by this object.
- */
-public class GlRectDrawer implements RendererCommon.GlDrawer {
-  // clang-format off
-  // Simple vertex shader, used for both YUV and OES.
-  private static final String VERTEX_SHADER_STRING =
-        "varying vec2 interp_tc;\n"
-      + "attribute vec4 in_pos;\n"
-      + "attribute vec4 in_tc;\n"
-      + "\n"
-      + "uniform mat4 texMatrix;\n"
-      + "\n"
-      + "void main() {\n"
-      + "    gl_Position = in_pos;\n"
-      + "    interp_tc = (texMatrix * in_tc).xy;\n"
+/** Simplest possible GL shader that just draws frames as opaque quads. */
+public class GlRectDrawer extends GlGenericDrawer {
+  private static final String FRAGMENT_SHADER = "void main() {\n"
+      + "  gl_FragColor = sample(tc);\n"
       + "}\n";
 
-  private static final String YUV_FRAGMENT_SHADER_STRING =
-        "precision mediump float;\n"
-      + "varying vec2 interp_tc;\n"
-      + "\n"
-      + "uniform sampler2D y_tex;\n"
-      + "uniform sampler2D u_tex;\n"
-      + "uniform sampler2D v_tex;\n"
-      + "\n"
-      + "void main() {\n"
-      // CSC according to http://www.fourcc.org/fccyvrgb.php
-      + "  float y = texture2D(y_tex, interp_tc).r;\n"
-      + "  float u = texture2D(u_tex, interp_tc).r - 0.5;\n"
-      + "  float v = texture2D(v_tex, interp_tc).r - 0.5;\n"
-      + "  gl_FragColor = vec4(y + 1.403 * v, "
-      + "                      y - 0.344 * u - 0.714 * v, "
-      + "                      y + 1.77 * u, 1);\n"
-      + "}\n";
+  private static class ShaderCallbacks implements GlGenericDrawer.ShaderCallbacks {
+    @Override
+    public void onNewShader(GlShader shader) {}
 
-  private static final String RGB_FRAGMENT_SHADER_STRING =
-        "precision mediump float;\n"
-      + "varying vec2 interp_tc;\n"
-      + "\n"
-      + "uniform sampler2D rgb_tex;\n"
-      + "\n"
-      + "void main() {\n"
-      + "  gl_FragColor = texture2D(rgb_tex, interp_tc);\n"
-      + "}\n";
-
-  private static final String OES_FRAGMENT_SHADER_STRING =
-        "#extension GL_OES_EGL_image_external : require\n"
-      + "precision mediump float;\n"
-      + "varying vec2 interp_tc;\n"
-      + "\n"
-      + "uniform samplerExternalOES oes_tex;\n"
-      + "\n"
-      + "void main() {\n"
-      + "  gl_FragColor = texture2D(oes_tex, interp_tc);\n"
-      + "}\n";
-  // clang-format on
-
-  // Vertex coordinates in Normalized Device Coordinates, i.e. (-1, -1) is bottom-left and (1, 1) is
-  // top-right.
-  private static final FloatBuffer FULL_RECTANGLE_BUF = GlUtil.createFloatBuffer(new float[] {
-      -1.0f, -1.0f, // Bottom left.
-      1.0f, -1.0f, // Bottom right.
-      -1.0f, 1.0f, // Top left.
-      1.0f, 1.0f, // Top right.
-  });
-
-  // Texture coordinates - (0, 0) is bottom-left and (1, 1) is top-right.
-  private static final FloatBuffer FULL_RECTANGLE_TEX_BUF = GlUtil.createFloatBuffer(new float[] {
-      0.0f, 0.0f, // Bottom left.
-      1.0f, 0.0f, // Bottom right.
-      0.0f, 1.0f, // Top left.
-      1.0f, 1.0f // Top right.
-  });
-
-  private static class Shader {
-    public final GlShader glShader;
-    public final int texMatrixLocation;
-
-    public Shader(String fragmentShader) {
-      this.glShader = new GlShader(VERTEX_SHADER_STRING, fragmentShader);
-      this.texMatrixLocation = glShader.getUniformLocation("texMatrix");
-    }
+    @Override
+    public void onPrepareShader(GlShader shader, float[] texMatrix, int frameWidth, int frameHeight,
+        int viewportWidth, int viewportHeight) {}
   }
 
-  // The keys are one of the fragments shaders above.
-  private final Map<String, Shader> shaders = new IdentityHashMap<String, Shader>();
-
-  /**
-   * Draw an OES texture frame with specified texture transformation matrix. Required resources are
-   * allocated at the first call to this function.
-   */
-  @Override
-  public void drawOes(int oesTextureId, float[] texMatrix, int frameWidth, int frameHeight,
-      int viewportX, int viewportY, int viewportWidth, int viewportHeight) {
-    prepareShader(OES_FRAGMENT_SHADER_STRING, texMatrix);
-    GLES20.glActiveTexture(GLES20.GL_TEXTURE0);
-    // updateTexImage() may be called from another thread in another EGL context, so we need to
-    // bind/unbind the texture in each draw call so that GLES understads it's a new texture.
-    GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, oesTextureId);
-    drawRectangle(viewportX, viewportY, viewportWidth, viewportHeight);
-    GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, 0);
-  }
-
-  /**
-   * Draw a RGB(A) texture frame with specified texture transformation matrix. Required resources
-   * are allocated at the first call to this function.
-   */
-  @Override
-  public void drawRgb(int textureId, float[] texMatrix, int frameWidth, int frameHeight,
-      int viewportX, int viewportY, int viewportWidth, int viewportHeight) {
-    prepareShader(RGB_FRAGMENT_SHADER_STRING, texMatrix);
-    GLES20.glActiveTexture(GLES20.GL_TEXTURE0);
-    GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, textureId);
-    drawRectangle(viewportX, viewportY, viewportWidth, viewportHeight);
-    // Unbind the texture as a precaution.
-    GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, 0);
-  }
-
-  /**
-   * Draw a YUV frame with specified texture transformation matrix. Required resources are
-   * allocated at the first call to this function.
-   */
-  @Override
-  public void drawYuv(int[] yuvTextures, float[] texMatrix, int frameWidth, int frameHeight,
-      int viewportX, int viewportY, int viewportWidth, int viewportHeight) {
-    prepareShader(YUV_FRAGMENT_SHADER_STRING, texMatrix);
-    // Bind the textures.
-    for (int i = 0; i < 3; ++i) {
-      GLES20.glActiveTexture(GLES20.GL_TEXTURE0 + i);
-      GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, yuvTextures[i]);
-    }
-    drawRectangle(viewportX, viewportY, viewportWidth, viewportHeight);
-    // Unbind the textures as a precaution..
-    for (int i = 0; i < 3; ++i) {
-      GLES20.glActiveTexture(GLES20.GL_TEXTURE0 + i);
-      GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, 0);
-    }
-  }
-
-  private void drawRectangle(int x, int y, int width, int height) {
-    // Draw quad.
-    GLES20.glViewport(x, y, width, height);
-    GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, 0, 4);
-  }
-
-  private void prepareShader(String fragmentShader, float[] texMatrix) {
-    final Shader shader;
-    if (shaders.containsKey(fragmentShader)) {
-      shader = shaders.get(fragmentShader);
-    } else {
-      // Lazy allocation.
-      shader = new Shader(fragmentShader);
-      shaders.put(fragmentShader, shader);
-      shader.glShader.useProgram();
-      // Initialize fragment shader uniform values.
-      if (YUV_FRAGMENT_SHADER_STRING.equals(fragmentShader)) {
-        GLES20.glUniform1i(shader.glShader.getUniformLocation("y_tex"), 0);
-        GLES20.glUniform1i(shader.glShader.getUniformLocation("u_tex"), 1);
-        GLES20.glUniform1i(shader.glShader.getUniformLocation("v_tex"), 2);
-      } else if (RGB_FRAGMENT_SHADER_STRING.equals(fragmentShader)) {
-        GLES20.glUniform1i(shader.glShader.getUniformLocation("rgb_tex"), 0);
-      } else if (OES_FRAGMENT_SHADER_STRING.equals(fragmentShader)) {
-        GLES20.glUniform1i(shader.glShader.getUniformLocation("oes_tex"), 0);
-      } else {
-        throw new IllegalStateException("Unknown fragment shader: " + fragmentShader);
-      }
-      GlUtil.checkNoGLES2Error("Initialize fragment shader uniform values.");
-      // Initialize vertex shader attributes.
-      shader.glShader.setVertexAttribArray("in_pos", 2, FULL_RECTANGLE_BUF);
-      shader.glShader.setVertexAttribArray("in_tc", 2, FULL_RECTANGLE_TEX_BUF);
-    }
-    shader.glShader.useProgram();
-    // Copy the texture transformation matrix over.
-    GLES20.glUniformMatrix4fv(shader.texMatrixLocation, 1, false, texMatrix, 0);
-  }
-
-  /**
-   * Release all GLES resources. This needs to be done manually, otherwise the resources are leaked.
-   */
-  @Override
-  public void release() {
-    for (Shader shader : shaders.values()) {
-      shader.glShader.release();
-    }
-    shaders.clear();
+  public GlRectDrawer() {
+    super(FRAGMENT_SHADER, new ShaderCallbacks());
   }
 }
diff --git a/sdk/android/api/org/webrtc/RendererCommon.java b/sdk/android/api/org/webrtc/RendererCommon.java
index 0554f11..ee73430 100644
--- a/sdk/android/api/org/webrtc/RendererCommon.java
+++ b/sdk/android/api/org/webrtc/RendererCommon.java
@@ -31,7 +31,12 @@
     public void onFrameResolutionChanged(int videoWidth, int videoHeight, int rotation);
   }
 
-  /** Interface for rendering frames on an EGLSurface. */
+  /**
+   * Interface for rendering frames on an EGLSurface with specified viewport location. Rotation,
+   * mirror, and cropping is specified using a 4x4 texture coordinate transform matrix. The frame
+   * input can either be an OES texture, RGB texture, or YUV textures in I420 format. The function
+   * release() must be called manually to free the resources held by this object.
+   */
   public static interface GlDrawer {
     /**
      * Functions for drawing frames with different sources. The rendering surface target is
diff --git a/sdk/android/api/org/webrtc/YuvConverter.java b/sdk/android/api/org/webrtc/YuvConverter.java
index f7922d6..be8d43e 100644
--- a/sdk/android/api/org/webrtc/YuvConverter.java
+++ b/sdk/android/api/org/webrtc/YuvConverter.java
@@ -10,10 +10,9 @@
 
 package org.webrtc;
 
-import android.opengl.GLES11Ext;
+import android.graphics.Matrix;
 import android.opengl.GLES20;
 import java.nio.ByteBuffer;
-import java.nio.FloatBuffer;
 import org.webrtc.VideoFrame.I420Buffer;
 import org.webrtc.VideoFrame.TextureBuffer;
 
@@ -22,45 +21,10 @@
  * should only be operated from a single thread with an active EGL context.
  */
 public class YuvConverter {
-  // Vertex coordinates in Normalized Device Coordinates, i.e.
-  // (-1, -1) is bottom-left and (1, 1) is top-right.
-  private static final FloatBuffer DEVICE_RECTANGLE = GlUtil.createFloatBuffer(new float[] {
-      -1.0f, -1.0f, // Bottom left.
-      1.0f, -1.0f, // Bottom right.
-      -1.0f, 1.0f, // Top left.
-      1.0f, 1.0f, // Top right.
-  });
-
-  // Texture coordinates - (0, 0) is bottom-left and (1, 1) is top-right.
-  private static final FloatBuffer TEXTURE_RECTANGLE = GlUtil.createFloatBuffer(new float[] {
-      0.0f, 0.0f, // Bottom left.
-      1.0f, 0.0f, // Bottom right.
-      0.0f, 1.0f, // Top left.
-      1.0f, 1.0f // Top right.
-  });
-
-  // clang-format off
-  private static final String VERTEX_SHADER =
-        "varying vec2 interp_tc;\n"
-      + "attribute vec4 in_pos;\n"
-      + "attribute vec4 in_tc;\n"
-      + "\n"
-      + "uniform mat4 texMatrix;\n"
-      + "\n"
-      + "void main() {\n"
-      + "    gl_Position = in_pos;\n"
-      + "    interp_tc = (texMatrix * in_tc).xy;\n"
-      + "}\n";
-
-  private static final String OES_FRAGMENT_SHADER =
-        "#extension GL_OES_EGL_image_external : require\n"
-      + "precision mediump float;\n"
-      + "varying vec2 interp_tc;\n"
-      + "\n"
-      + "uniform samplerExternalOES tex;\n"
+  private static final String FRAGMENT_SHADER =
       // Difference in texture coordinate corresponding to one
       // sub-pixel in the x direction.
-      + "uniform vec2 xUnit;\n"
+      "uniform vec2 xUnit;\n"
       // Color conversion coefficients, including constant term
       + "uniform vec4 coeffs;\n"
       + "\n"
@@ -72,52 +36,66 @@
       // try to do it as a vec3 x mat3x4, followed by an add in of a
       // constant vector.
       + "  gl_FragColor.r = coeffs.a + dot(coeffs.rgb,\n"
-      + "      texture2D(tex, interp_tc - 1.5 * xUnit).rgb);\n"
+      + "      sample(tc - 1.5 * xUnit).rgb);\n"
       + "  gl_FragColor.g = coeffs.a + dot(coeffs.rgb,\n"
-      + "      texture2D(tex, interp_tc - 0.5 * xUnit).rgb);\n"
+      + "      sample(tc - 0.5 * xUnit).rgb);\n"
       + "  gl_FragColor.b = coeffs.a + dot(coeffs.rgb,\n"
-      + "      texture2D(tex, interp_tc + 0.5 * xUnit).rgb);\n"
+      + "      sample(tc + 0.5 * xUnit).rgb);\n"
       + "  gl_FragColor.a = coeffs.a + dot(coeffs.rgb,\n"
-      + "      texture2D(tex, interp_tc + 1.5 * xUnit).rgb);\n"
+      + "      sample(tc + 1.5 * xUnit).rgb);\n"
       + "}\n";
 
-  private static final String RGB_FRAGMENT_SHADER =
-        "precision mediump float;\n"
-      + "varying vec2 interp_tc;\n"
-      + "\n"
-      + "uniform sampler2D tex;\n"
-      // Difference in texture coordinate corresponding to one
-      // sub-pixel in the x direction.
-      + "uniform vec2 xUnit;\n"
-      // Color conversion coefficients, including constant term
-      + "uniform vec4 coeffs;\n"
-      + "\n"
-      + "void main() {\n"
-      // Since the alpha read from the texture is always 1, this could
-      // be written as a mat4 x vec4 multiply. However, that seems to
-      // give a worse framerate, possibly because the additional
-      // multiplies by 1.0 consume resources. TODO(nisse): Could also
-      // try to do it as a vec3 x mat3x4, followed by an add in of a
-      // constant vector.
-      + "  gl_FragColor.r = coeffs.a + dot(coeffs.rgb,\n"
-      + "      texture2D(tex, interp_tc - 1.5 * xUnit).rgb);\n"
-      + "  gl_FragColor.g = coeffs.a + dot(coeffs.rgb,\n"
-      + "      texture2D(tex, interp_tc - 0.5 * xUnit).rgb);\n"
-      + "  gl_FragColor.b = coeffs.a + dot(coeffs.rgb,\n"
-      + "      texture2D(tex, interp_tc + 0.5 * xUnit).rgb);\n"
-      + "  gl_FragColor.a = coeffs.a + dot(coeffs.rgb,\n"
-      + "      texture2D(tex, interp_tc + 1.5 * xUnit).rgb);\n"
-      + "}\n";
-  // clang-format on
+  private static class ShaderCallbacks implements GlGenericDrawer.ShaderCallbacks {
+    // Y'UV444 to RGB888, see https://en.wikipedia.org/wiki/YUV#Y.27UV444_to_RGB888_conversion. We
+    // use the ITU-R coefficients for U and V.
+    private static final float[] yCoeffs = new float[] {0.2987856f, 0.5871095f, 0.1141049f, 0.0f};
+    private static final float[] uCoeffs =
+        new float[] {-0.168805420f, -0.3317003f, 0.5005057f, 0.5f};
+    private static final float[] vCoeffs = new float[] {0.4997964f, -0.4184672f, -0.0813292f, 0.5f};
+
+    private int xUnitLoc;
+    private int coeffsLoc;
+
+    private float[] coeffs;
+    private float stepSize;
+
+    public void setPlaneY() {
+      coeffs = yCoeffs;
+      stepSize = 1.0f;
+    }
+
+    public void setPlaneU() {
+      coeffs = uCoeffs;
+      stepSize = 2.0f;
+    }
+
+    public void setPlaneV() {
+      coeffs = vCoeffs;
+      stepSize = 2.0f;
+    }
+
+    @Override
+    public void onNewShader(GlShader shader) {
+      xUnitLoc = shader.getUniformLocation("xUnit");
+      coeffsLoc = shader.getUniformLocation("coeffs");
+    }
+
+    @Override
+    public void onPrepareShader(GlShader shader, float[] texMatrix, int frameWidth, int frameHeight,
+        int viewportWidth, int viewportHeight) {
+      GLES20.glUniform4fv(coeffsLoc, /* count= */ 1, coeffs, /* offset= */ 0);
+      // Matrix * (1;0;0;0) / (width / stepSize). Note that OpenGL uses column major order.
+      GLES20.glUniform2f(
+          xUnitLoc, stepSize * texMatrix[0] / frameWidth, stepSize * texMatrix[1] / frameWidth);
+    }
+  }
 
   private final ThreadUtils.ThreadChecker threadChecker = new ThreadUtils.ThreadChecker();
-  private final GlTextureFrameBuffer textureFrameBuffer = new GlTextureFrameBuffer(GLES20.GL_RGBA);
-  private TextureBuffer.Type shaderTextureType;
-  private GlShader shader;
-  private int texMatrixLoc;
-  private int xUnitLoc;
-  private int coeffsLoc;
+  private final GlTextureFrameBuffer i420TextureFrameBuffer =
+      new GlTextureFrameBuffer(GLES20.GL_RGBA);
   private boolean released = false;
+  private final ShaderCallbacks shaderCallbacks = new ShaderCallbacks();
+  private final GlGenericDrawer drawer = new GlGenericDrawer(FRAGMENT_SHADER, shaderCallbacks);
 
   /**
    * This class should be constructed on a thread that has an active EGL context.
@@ -127,96 +105,11 @@
   }
 
   /** Converts the texture buffer to I420. */
-  public I420Buffer convert(TextureBuffer textureBuffer) {
-    final int width = textureBuffer.getWidth();
-    final int height = textureBuffer.getHeight();
-
-    // SurfaceTextureHelper requires a stride that is divisible by 8.  Round width up.
-    // See SurfaceTextureHelper for details on the size and format.
-    final int stride = ((width + 7) / 8) * 8;
-    final int uvHeight = (height + 1) / 2;
-    // Due to the layout used by SurfaceTextureHelper, vPos + stride * uvHeight would overrun the
-    // buffer.  Add one row at the bottom to compensate for this.  There will never be data in the
-    // extra row, but now other code does not have to deal with v stride * v height exceeding the
-    // buffer's capacity.
-    final int size = stride * (height + uvHeight + 1);
-    ByteBuffer buffer = JniCommon.nativeAllocateByteBuffer(size);
-    convert(buffer, width, height, stride, textureBuffer.getTextureId(),
-        RendererCommon.convertMatrixFromAndroidGraphicsMatrix(textureBuffer.getTransformMatrix()),
-        textureBuffer.getType());
-
-    final int yPos = 0;
-    final int uPos = yPos + stride * height;
-    // Rows of U and V alternate in the buffer, so V data starts after the first row of U.
-    final int vPos = uPos + stride / 2;
-
-    buffer.position(yPos);
-    buffer.limit(yPos + stride * height);
-    ByteBuffer dataY = buffer.slice();
-
-    buffer.position(uPos);
-    buffer.limit(uPos + stride * uvHeight);
-    ByteBuffer dataU = buffer.slice();
-
-    buffer.position(vPos);
-    buffer.limit(vPos + stride * uvHeight);
-    ByteBuffer dataV = buffer.slice();
-
-    // SurfaceTextureHelper uses the same stride for Y, U, and V data.
-    return JavaI420Buffer.wrap(width, height, dataY, stride, dataU, stride, dataV, stride,
-        () -> { JniCommon.nativeFreeByteBuffer(buffer); });
-  }
-
-  /** Deprecated, use convert(TextureBuffer). */
-  @Deprecated
-  void convert(ByteBuffer buf, int width, int height, int stride, int srcTextureId,
-      float[] transformMatrix) {
-    convert(buf, width, height, stride, srcTextureId, transformMatrix, TextureBuffer.Type.OES);
-  }
-
-  private void initShader(TextureBuffer.Type textureType) {
-    if (shader != null) {
-      shader.release();
-    }
-
-    final String fragmentShader;
-    switch (textureType) {
-      case OES:
-        fragmentShader = OES_FRAGMENT_SHADER;
-        break;
-      case RGB:
-        fragmentShader = RGB_FRAGMENT_SHADER;
-        break;
-      default:
-        throw new IllegalArgumentException("Unsupported texture type.");
-    }
-
-    shaderTextureType = textureType;
-    shader = new GlShader(VERTEX_SHADER, fragmentShader);
-    shader.useProgram();
-    texMatrixLoc = shader.getUniformLocation("texMatrix");
-    xUnitLoc = shader.getUniformLocation("xUnit");
-    coeffsLoc = shader.getUniformLocation("coeffs");
-    GLES20.glUniform1i(shader.getUniformLocation("tex"), 0);
-    GlUtil.checkNoGLES2Error("Initialize fragment shader uniform values.");
-    // Initialize vertex shader attributes.
-    shader.setVertexAttribArray("in_pos", 2, DEVICE_RECTANGLE);
-    // If the width is not a multiple of 4 pixels, the texture
-    // will be scaled up slightly and clipped at the right border.
-    shader.setVertexAttribArray("in_tc", 2, TEXTURE_RECTANGLE);
-  }
-
-  private void convert(ByteBuffer buf, int width, int height, int stride, int srcTextureId,
-      float[] transformMatrix, TextureBuffer.Type textureType) {
+  public I420Buffer convert(TextureBuffer inputTextureBuffer) {
     threadChecker.checkIsOnValidThread();
     if (released) {
       throw new IllegalStateException("YuvConverter.convert called on released object");
     }
-    if (textureType != shaderTextureType) {
-      initShader(textureType);
-    }
-    shader.useProgram();
-
     // We draw into a buffer laid out like
     //
     //    +---------+
@@ -245,83 +138,83 @@
     // Since the V data needs to start on a boundary of such a
     // larger pixel, it is not sufficient that |stride| is even, it
     // has to be a multiple of 8 pixels.
+    final int frameWidth = inputTextureBuffer.getWidth();
+    final int frameHeight = inputTextureBuffer.getHeight();
+    final int stride = ((frameWidth + 7) / 8) * 8;
+    final int uvHeight = (frameHeight + 1) / 2;
+    // Total height of the combined memory layout.
+    final int totalHeight = frameHeight + uvHeight;
+    final ByteBuffer i420ByteBuffer = JniCommon.nativeAllocateByteBuffer(stride * totalHeight);
+    // Viewport width is divided by four since we are squeezing in four color bytes in each RGBA
+    // pixel.
+    final int viewportWidth = stride / 4;
 
-    if (stride % 8 != 0) {
-      throw new IllegalArgumentException("Invalid stride, must be a multiple of 8");
-    }
-    if (stride < width) {
-      throw new IllegalArgumentException("Invalid stride, must >= width");
-    }
+    // Produce a frame buffer starting at top-left corner, not bottom-left.
+    final Matrix renderMatrix = new Matrix();
+    renderMatrix.preTranslate(0.5f, 0.5f);
+    renderMatrix.preScale(1f, -1f);
+    renderMatrix.preTranslate(-0.5f, -0.5f);
 
-    int y_width = (width + 3) / 4;
-    int uv_width = (width + 7) / 8;
-    int uv_height = (height + 1) / 2;
-    int total_height = height + uv_height;
-    int size = stride * total_height;
-
-    if (buf.capacity() < size) {
-      throw new IllegalArgumentException("YuvConverter.convert called with too small buffer");
-    }
-    // Produce a frame buffer starting at top-left corner, not
-    // bottom-left.
-    transformMatrix =
-        RendererCommon.multiplyMatrices(transformMatrix, RendererCommon.verticalFlipMatrix());
-
-    final int frameBufferWidth = stride / 4;
-    final int frameBufferHeight = total_height;
-    textureFrameBuffer.setSize(frameBufferWidth, frameBufferHeight);
+    i420TextureFrameBuffer.setSize(viewportWidth, totalHeight);
 
     // Bind our framebuffer.
-    GLES20.glBindFramebuffer(GLES20.GL_FRAMEBUFFER, textureFrameBuffer.getFrameBufferId());
+    GLES20.glBindFramebuffer(GLES20.GL_FRAMEBUFFER, i420TextureFrameBuffer.getFrameBufferId());
     GlUtil.checkNoGLES2Error("glBindFramebuffer");
 
-    GLES20.glActiveTexture(GLES20.GL_TEXTURE0);
-    GLES20.glBindTexture(textureType.getGlTarget(), srcTextureId);
-    GLES20.glUniformMatrix4fv(texMatrixLoc, 1, false, transformMatrix, 0);
+    // Draw Y.
+    shaderCallbacks.setPlaneY();
+    VideoFrameDrawer.drawTexture(drawer, inputTextureBuffer, renderMatrix, frameWidth, frameHeight,
+        /* viewportX= */ 0, /* viewportY= */ 0, viewportWidth,
+        /* viewportHeight= */ frameHeight);
 
-    // Draw Y
-    GLES20.glViewport(0, 0, y_width, height);
-    // Matrix * (1;0;0;0) / width. Note that opengl uses column major order.
-    GLES20.glUniform2f(xUnitLoc, transformMatrix[0] / width, transformMatrix[1] / width);
-    // Y'UV444 to RGB888, see
-    // https://en.wikipedia.org/wiki/YUV#Y.27UV444_to_RGB888_conversion.
-    // We use the ITU-R coefficients for U and V */
-    GLES20.glUniform4f(coeffsLoc, 0.2987856f, 0.5871095f, 0.1141049f, 0.0f);
-    GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, 0, 4);
+    // Draw U.
+    shaderCallbacks.setPlaneU();
+    VideoFrameDrawer.drawTexture(drawer, inputTextureBuffer, renderMatrix, frameWidth, frameHeight,
+        /* viewportX= */ 0, /* viewportY= */ frameHeight, viewportWidth / 2,
+        /* viewportHeight= */ uvHeight);
 
-    // Draw U
-    GLES20.glViewport(0, height, uv_width, uv_height);
-    // Matrix * (1;0;0;0) / (width / 2). Note that opengl uses column major order.
-    GLES20.glUniform2f(
-        xUnitLoc, 2.0f * transformMatrix[0] / width, 2.0f * transformMatrix[1] / width);
-    GLES20.glUniform4f(coeffsLoc, -0.168805420f, -0.3317003f, 0.5005057f, 0.5f);
-    GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, 0, 4);
+    // Draw V.
+    shaderCallbacks.setPlaneV();
+    VideoFrameDrawer.drawTexture(drawer, inputTextureBuffer, renderMatrix, frameWidth, frameHeight,
+        /* viewportX= */ viewportWidth / 2, /* viewportY= */ frameHeight, viewportWidth / 2,
+        /* viewportHeight= */ uvHeight);
 
-    // Draw V
-    GLES20.glViewport(stride / 8, height, uv_width, uv_height);
-    GLES20.glUniform4f(coeffsLoc, 0.4997964f, -0.4184672f, -0.0813292f, 0.5f);
-    GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, 0, 4);
-
-    GLES20.glReadPixels(
-        0, 0, frameBufferWidth, frameBufferHeight, GLES20.GL_RGBA, GLES20.GL_UNSIGNED_BYTE, buf);
+    GLES20.glReadPixels(0, 0, i420TextureFrameBuffer.getWidth(), i420TextureFrameBuffer.getHeight(),
+        GLES20.GL_RGBA, GLES20.GL_UNSIGNED_BYTE, i420ByteBuffer);
 
     GlUtil.checkNoGLES2Error("YuvConverter.convert");
 
     // Restore normal framebuffer.
     GLES20.glBindFramebuffer(GLES20.GL_FRAMEBUFFER, 0);
-    GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, 0);
 
-    // Unbind texture. Reportedly needed on some devices to get
-    // the texture updated from the camera.
-    GLES20.glBindTexture(textureType.getGlTarget(), 0);
+    // Prepare Y, U, and V ByteBuffer slices.
+    final int yPos = 0;
+    final int uPos = yPos + stride * frameHeight;
+    // Rows of U and V alternate in the buffer, so V data starts after the first row of U.
+    final int vPos = uPos + stride / 2;
+
+    i420ByteBuffer.position(yPos);
+    i420ByteBuffer.limit(yPos + stride * frameHeight);
+    final ByteBuffer dataY = i420ByteBuffer.slice();
+
+    i420ByteBuffer.position(uPos);
+    // The last row does not have padding.
+    final int uvSize = stride * (uvHeight - 1) + stride / 2;
+    i420ByteBuffer.limit(uPos + uvSize);
+    final ByteBuffer dataU = i420ByteBuffer.slice();
+
+    i420ByteBuffer.position(vPos);
+    i420ByteBuffer.limit(vPos + uvSize);
+    final ByteBuffer dataV = i420ByteBuffer.slice();
+
+    return JavaI420Buffer.wrap(frameWidth, frameHeight, dataY, stride, dataU, stride, dataV, stride,
+        () -> { JniCommon.nativeFreeByteBuffer(i420ByteBuffer); });
   }
 
   public void release() {
     threadChecker.checkIsOnValidThread();
     released = true;
-    if (shader != null) {
-      shader.release();
-    }
-    textureFrameBuffer.release();
+    drawer.release();
+    i420TextureFrameBuffer.release();
   }
 }
diff --git a/sdk/android/src/java/org/webrtc/GlGenericDrawer.java b/sdk/android/src/java/org/webrtc/GlGenericDrawer.java
new file mode 100644
index 0000000..cd7b1e8
--- /dev/null
+++ b/sdk/android/src/java/org/webrtc/GlGenericDrawer.java
@@ -0,0 +1,279 @@
+/*
+ *  Copyright 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+package org.webrtc;
+
+import android.opengl.GLES11Ext;
+import android.opengl.GLES20;
+import java.nio.FloatBuffer;
+import javax.annotation.Nullable;
+import org.webrtc.GlShader;
+import org.webrtc.GlUtil;
+import org.webrtc.RendererCommon;
+
+/**
+ * Helper class to implement an instance of RendererCommon.GlDrawer that can accept multiple input
+ * sources (OES, RGB, or YUV) using a generic fragment shader as input. The generic fragment shader
+ * should sample pixel values from the function "sample" that will be provided by this class and
+ * provides an abstraction for the input source type (OES, RGB, or YUV). The texture coordinate
+ * variable name will be "tc" and the texture matrix in the vertex shader will be "tex_mat". The
+ * simplest possible generic shader that just draws pixel from the frame unmodified looks like:
+ * void main() {
+ *   gl_FragColor = sample(tc);
+ * }
+ * This class covers the cases for most simple shaders and generates the necessary boiler plate.
+ * Advanced shaders can always implement RendererCommon.GlDrawer directly.
+ */
+class GlGenericDrawer implements RendererCommon.GlDrawer {
+  /**
+   * The different shader types representing different input sources. YUV here represents three
+   * separate Y, U, V textures.
+   */
+  public static enum ShaderType { OES, RGB, YUV }
+
+  /**
+   * The shader callbacks is used to customize behavior for a GlDrawer. It provides a hook to set
+   * uniform variables in the shader before a frame is drawn.
+   */
+  public static interface ShaderCallbacks {
+    /**
+     * This callback is called when a new shader has been compiled and created. It will be called
+     * for the first frame as well as when the shader type is changed. This callback can be used to
+     * do custom initialization of the shader that only needs to happen once.
+     */
+    void onNewShader(GlShader shader);
+
+    /**
+     * This callback is called before rendering a frame. It can be used to do custom preparation of
+     * the shader that needs to happen every frame.
+     */
+    void onPrepareShader(GlShader shader, float[] texMatrix, int frameWidth, int frameHeight,
+        int viewportWidth, int viewportHeight);
+  }
+
+  private static final String INPUT_VERTEX_COORDINATE_NAME = "in_pos";
+  private static final String INPUT_TEXTURE_COORDINATE_NAME = "in_tc";
+  private static final String TEXTURE_MATRIX_NAME = "tex_mat";
+  private static final String DEFAULT_VERTEX_SHADER_STRING = "varying vec2 tc;\n"
+      + "attribute vec4 in_pos;\n"
+      + "attribute vec4 in_tc;\n"
+      + "uniform mat4 tex_mat;\n"
+      + "void main() {\n"
+      + "  gl_Position = in_pos;\n"
+      + "  tc = (tex_mat * in_tc).xy;\n"
+      + "}\n";
+
+  // Vertex coordinates in Normalized Device Coordinates, i.e. (-1, -1) is bottom-left and (1, 1)
+  // is top-right.
+  private static final FloatBuffer FULL_RECTANGLE_BUFFER = GlUtil.createFloatBuffer(new float[] {
+      -1.0f, -1.0f, // Bottom left.
+      1.0f, -1.0f, // Bottom right.
+      -1.0f, 1.0f, // Top left.
+      1.0f, 1.0f, // Top right.
+  });
+
+  // Texture coordinates - (0, 0) is bottom-left and (1, 1) is top-right.
+  private static final FloatBuffer FULL_RECTANGLE_TEXTURE_BUFFER =
+      GlUtil.createFloatBuffer(new float[] {
+          0.0f, 0.0f, // Bottom left.
+          1.0f, 0.0f, // Bottom right.
+          0.0f, 1.0f, // Top left.
+          1.0f, 1.0f, // Top right.
+      });
+
+  static String createFragmentShaderString(String genericFragmentSource, ShaderType shaderType) {
+    final StringBuilder stringBuilder = new StringBuilder();
+    if (shaderType == ShaderType.OES) {
+      stringBuilder.append("#extension GL_OES_EGL_image_external : require\n");
+    }
+    stringBuilder.append("precision mediump float;\n");
+    stringBuilder.append("varying vec2 tc;\n");
+
+    if (shaderType == ShaderType.YUV) {
+      stringBuilder.append("uniform sampler2D y_tex;\n");
+      stringBuilder.append("uniform sampler2D u_tex;\n");
+      stringBuilder.append("uniform sampler2D v_tex;\n");
+
+      // Add separate function for sampling texture.
+      stringBuilder.append("vec4 sample(vec2 p) {\n");
+      stringBuilder.append("  float y = texture2D(y_tex, p).r;\n");
+      stringBuilder.append("  float u = texture2D(u_tex, p).r - 0.5;\n");
+      stringBuilder.append("  float v = texture2D(v_tex, p).r - 0.5;\n");
+      stringBuilder.append(
+          "  return vec4(y + 1.403 * v, y - 0.344 * u - 0.714 * v, y + 1.77 * u, 1);\n");
+      stringBuilder.append("}\n");
+      stringBuilder.append(genericFragmentSource);
+    } else {
+      final String samplerName = shaderType == ShaderType.OES ? "samplerExternalOES" : "sampler2D";
+      stringBuilder.append("uniform ").append(samplerName).append(" tex;\n");
+
+      // Update the sampling function in-place.
+      stringBuilder.append(genericFragmentSource.replace("sample(", "texture2D(tex, "));
+    }
+
+    return stringBuilder.toString();
+  }
+
+  private final String genericFragmentSource;
+  private final String vertexShader;
+  private final ShaderCallbacks shaderCallbacks;
+  @Nullable private ShaderType currentShaderType;
+  @Nullable private GlShader currentShader;
+  private int inPosLocation;
+  private int inTcLocation;
+  private int texMatrixLocation;
+
+  public GlGenericDrawer(String genericFragmentSource, ShaderCallbacks shaderCallbacks) {
+    this(DEFAULT_VERTEX_SHADER_STRING, genericFragmentSource, shaderCallbacks);
+  }
+
+  public GlGenericDrawer(
+      String vertexShader, String genericFragmentSource, ShaderCallbacks shaderCallbacks) {
+    this.vertexShader = vertexShader;
+    this.genericFragmentSource = genericFragmentSource;
+    this.shaderCallbacks = shaderCallbacks;
+  }
+
+  // Visible for testing.
+  GlShader createShader(ShaderType shaderType) {
+    return new GlShader(
+        vertexShader, createFragmentShaderString(genericFragmentSource, shaderType));
+  }
+
+  /**
+   * Draw an OES texture frame with specified texture transformation matrix. Required resources are
+   * allocated at the first call to this function.
+   */
+  @Override
+  public void drawOes(int oesTextureId, float[] texMatrix, int frameWidth, int frameHeight,
+      int viewportX, int viewportY, int viewportWidth, int viewportHeight) {
+    prepareShader(
+        ShaderType.OES, texMatrix, frameWidth, frameHeight, viewportWidth, viewportHeight);
+    // Bind the texture.
+    GLES20.glActiveTexture(GLES20.GL_TEXTURE0);
+    GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, oesTextureId);
+    // Draw the texture.
+    GLES20.glViewport(viewportX, viewportY, viewportWidth, viewportHeight);
+    GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, 0, 4);
+    // Unbind the texture as a precaution.
+    GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, 0);
+  }
+
+  /**
+   * Draw a RGB(A) texture frame with specified texture transformation matrix. Required resources
+   * are allocated at the first call to this function.
+   */
+  @Override
+  public void drawRgb(int textureId, float[] texMatrix, int frameWidth, int frameHeight,
+      int viewportX, int viewportY, int viewportWidth, int viewportHeight) {
+    prepareShader(
+        ShaderType.RGB, texMatrix, frameWidth, frameHeight, viewportWidth, viewportHeight);
+    // Bind the texture.
+    GLES20.glActiveTexture(GLES20.GL_TEXTURE0);
+    GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, textureId);
+    // Draw the texture.
+    GLES20.glViewport(viewportX, viewportY, viewportWidth, viewportHeight);
+    GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, 0, 4);
+    // Unbind the texture as a precaution.
+    GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, 0);
+  }
+
+  /**
+   * Draw a YUV frame with specified texture transformation matrix. Required resources are allocated
+   * at the first call to this function.
+   */
+  @Override
+  public void drawYuv(int[] yuvTextures, float[] texMatrix, int frameWidth, int frameHeight,
+      int viewportX, int viewportY, int viewportWidth, int viewportHeight) {
+    prepareShader(
+        ShaderType.YUV, texMatrix, frameWidth, frameHeight, viewportWidth, viewportHeight);
+    // Bind the textures.
+    for (int i = 0; i < 3; ++i) {
+      GLES20.glActiveTexture(GLES20.GL_TEXTURE0 + i);
+      GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, yuvTextures[i]);
+    }
+    // Draw the textures.
+    GLES20.glViewport(viewportX, viewportY, viewportWidth, viewportHeight);
+    GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, 0, 4);
+    // Unbind the textures as a precaution.
+    for (int i = 0; i < 3; ++i) {
+      GLES20.glActiveTexture(GLES20.GL_TEXTURE0 + i);
+      GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, 0);
+    }
+  }
+
+  private void prepareShader(ShaderType shaderType, float[] texMatrix, int frameWidth,
+      int frameHeight, int viewportWidth, int viewportHeight) {
+    final GlShader shader;
+    if (shaderType.equals(currentShaderType)) {
+      // Same shader type as before, reuse exising shader.
+      shader = currentShader;
+    } else {
+      // Allocate new shader.
+      currentShaderType = shaderType;
+      if (currentShader != null) {
+        currentShader.release();
+      }
+      shader = createShader(shaderType);
+      currentShader = shader;
+
+      shader.useProgram();
+      // Set input texture units.
+      if (shaderType == ShaderType.YUV) {
+        GLES20.glUniform1i(shader.getUniformLocation("y_tex"), 0);
+        GLES20.glUniform1i(shader.getUniformLocation("u_tex"), 1);
+        GLES20.glUniform1i(shader.getUniformLocation("v_tex"), 2);
+      } else {
+        GLES20.glUniform1i(shader.getUniformLocation("tex"), 0);
+      }
+
+      GlUtil.checkNoGLES2Error("Create shader");
+      shaderCallbacks.onNewShader(shader);
+      texMatrixLocation = shader.getUniformLocation(TEXTURE_MATRIX_NAME);
+      inPosLocation = shader.getAttribLocation(INPUT_VERTEX_COORDINATE_NAME);
+      inTcLocation = shader.getAttribLocation(INPUT_TEXTURE_COORDINATE_NAME);
+    }
+
+    shader.useProgram();
+
+    // Upload the vertex coordinates.
+    GLES20.glEnableVertexAttribArray(inPosLocation);
+    GLES20.glVertexAttribPointer(inPosLocation, /* size= */ 2,
+        /* type= */ GLES20.GL_FLOAT, /* normalized= */ false, /* stride= */ 0,
+        FULL_RECTANGLE_BUFFER);
+
+    // Upload the texture coordinates.
+    GLES20.glEnableVertexAttribArray(inTcLocation);
+    GLES20.glVertexAttribPointer(inTcLocation, /* size= */ 2,
+        /* type= */ GLES20.GL_FLOAT, /* normalized= */ false, /* stride= */ 0,
+        FULL_RECTANGLE_TEXTURE_BUFFER);
+
+    // Upload the texture transformation matrix.
+    GLES20.glUniformMatrix4fv(
+        texMatrixLocation, 1 /* count= */, false /* transpose= */, texMatrix, 0 /* offset= */);
+
+    // Do custom per-frame shader preparation.
+    shaderCallbacks.onPrepareShader(
+        shader, texMatrix, frameWidth, frameHeight, viewportWidth, viewportHeight);
+    GlUtil.checkNoGLES2Error("Prepare shader");
+  }
+
+  /**
+   * Release all GLES resources. This needs to be done manually, otherwise the resources are leaked.
+   */
+  @Override
+  public void release() {
+    if (currentShader != null) {
+      currentShader.release();
+      currentShader = null;
+      currentShaderType = null;
+    }
+  }
+}
diff --git a/sdk/android/tests/src/org/webrtc/GlGenericDrawerTest.java b/sdk/android/tests/src/org/webrtc/GlGenericDrawerTest.java
new file mode 100644
index 0000000..06b5367
--- /dev/null
+++ b/sdk/android/tests/src/org/webrtc/GlGenericDrawerTest.java
@@ -0,0 +1,158 @@
+/*
+ *  Copyright 2018 The WebRTC Project Authors. All rights reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+package org.webrtc;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+
+import org.chromium.testing.local.LocalRobolectricTestRunner;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.robolectric.annotation.Config;
+import org.webrtc.GlShader;
+
+@RunWith(LocalRobolectricTestRunner.class)
+@Config(manifest = Config.NONE)
+public class GlGenericDrawerTest {
+  // Simplest possible valid generic fragment shader.
+  private static final String FRAGMENT_SHADER = "void main() {\n"
+      + "  gl_FragColor = sample(tc);\n"
+      + "}\n";
+  private static final int TEXTURE_ID = 3;
+  private static final float[] TEX_MATRIX =
+      new float[] {1, 2, 3, 4, -1, -2, -3, -4, 0, 0, 1, 0, 0, 0, 0, 1};
+  private static final int FRAME_WIDTH = 640;
+  private static final int FRAME_HEIGHT = 480;
+  private static final int VIEWPORT_X = 3;
+  private static final int VIEWPORT_Y = 5;
+  private static final int VIEWPORT_WIDTH = 500;
+  private static final int VIEWPORT_HEIGHT = 500;
+
+  // Replace OpenGLES GlShader dependency with a mock.
+  private class GlGenericDrawerForTest extends GlGenericDrawer {
+    public GlGenericDrawerForTest(String genericFragmentSource, ShaderCallbacks shaderCallbacks) {
+      super(genericFragmentSource, shaderCallbacks);
+    }
+
+    @Override
+    GlShader createShader(ShaderType shaderType) {
+      return mockedShader;
+    }
+  }
+
+  private GlShader mockedShader;
+  private GlGenericDrawer glGenericDrawer;
+  private GlGenericDrawer.ShaderCallbacks mockedCallbacks;
+
+  @Before
+  public void setUp() {
+    mockedShader = mock(GlShader.class);
+    mockedCallbacks = mock(GlGenericDrawer.ShaderCallbacks.class);
+    glGenericDrawer = new GlGenericDrawerForTest(FRAGMENT_SHADER, mockedCallbacks);
+  }
+
+  @After
+  public void tearDown() {
+    verifyNoMoreInteractions(mockedCallbacks);
+  }
+
+  @Test
+  public void testOesFragmentShader() {
+    final String expectedOesFragmentShader = "#extension GL_OES_EGL_image_external : require\n"
+        + "precision mediump float;\n"
+        + "varying vec2 tc;\n"
+        + "uniform samplerExternalOES tex;\n"
+        + "void main() {\n"
+        + "  gl_FragColor = texture2D(tex, tc);\n"
+        + "}\n";
+    final String oesFragmentShader =
+        GlGenericDrawer.createFragmentShaderString(FRAGMENT_SHADER, GlGenericDrawer.ShaderType.OES);
+    assertEquals(expectedOesFragmentShader, oesFragmentShader);
+  }
+
+  @Test
+  public void testRgbFragmentShader() {
+    final String expectedRgbFragmentShader = "precision mediump float;\n"
+        + "varying vec2 tc;\n"
+        + "uniform sampler2D tex;\n"
+        + "void main() {\n"
+        + "  gl_FragColor = texture2D(tex, tc);\n"
+        + "}\n";
+    final String rgbFragmentShader =
+        GlGenericDrawer.createFragmentShaderString(FRAGMENT_SHADER, GlGenericDrawer.ShaderType.RGB);
+    assertEquals(expectedRgbFragmentShader, rgbFragmentShader);
+  }
+
+  @Test
+  public void testYuvFragmentShader() {
+    final String expectedYuvFragmentShader = "precision mediump float;\n"
+        + "varying vec2 tc;\n"
+        + "uniform sampler2D y_tex;\n"
+        + "uniform sampler2D u_tex;\n"
+        + "uniform sampler2D v_tex;\n"
+        + "vec4 sample(vec2 p) {\n"
+        + "  float y = texture2D(y_tex, p).r;\n"
+        + "  float u = texture2D(u_tex, p).r - 0.5;\n"
+        + "  float v = texture2D(v_tex, p).r - 0.5;\n"
+        + "  return vec4(y + 1.403 * v, y - 0.344 * u - 0.714 * v, y + 1.77 * u, 1);\n"
+        + "}\n"
+        + "void main() {\n"
+        + "  gl_FragColor = sample(tc);\n"
+        + "}\n";
+    final String yuvFragmentShader =
+        GlGenericDrawer.createFragmentShaderString(FRAGMENT_SHADER, GlGenericDrawer.ShaderType.YUV);
+    assertEquals(expectedYuvFragmentShader, yuvFragmentShader);
+  }
+
+  @Test
+  public void testShaderCallbacksOneRgbFrame() {
+    glGenericDrawer.drawRgb(TEXTURE_ID, TEX_MATRIX, FRAME_WIDTH, FRAME_HEIGHT, VIEWPORT_X,
+        VIEWPORT_Y, VIEWPORT_WIDTH, VIEWPORT_HEIGHT);
+
+    verify(mockedCallbacks).onNewShader(mockedShader);
+    verify(mockedCallbacks)
+        .onPrepareShader(
+            mockedShader, TEX_MATRIX, FRAME_WIDTH, FRAME_HEIGHT, VIEWPORT_WIDTH, VIEWPORT_HEIGHT);
+  }
+
+  @Test
+  public void testShaderCallbacksTwoRgbFrames() {
+    glGenericDrawer.drawRgb(TEXTURE_ID, TEX_MATRIX, FRAME_WIDTH, FRAME_HEIGHT, VIEWPORT_X,
+        VIEWPORT_Y, VIEWPORT_WIDTH, VIEWPORT_HEIGHT);
+    glGenericDrawer.drawRgb(TEXTURE_ID, TEX_MATRIX, FRAME_WIDTH, FRAME_HEIGHT, VIEWPORT_X,
+        VIEWPORT_Y, VIEWPORT_WIDTH, VIEWPORT_HEIGHT);
+
+    // Expect only one shader to be created, but two frames to be drawn.
+    verify(mockedCallbacks, times(1)).onNewShader(mockedShader);
+    verify(mockedCallbacks, times(2))
+        .onPrepareShader(
+            mockedShader, TEX_MATRIX, FRAME_WIDTH, FRAME_HEIGHT, VIEWPORT_WIDTH, VIEWPORT_HEIGHT);
+  }
+
+  @Test
+  public void testShaderCallbacksChangingShaderType() {
+    glGenericDrawer.drawRgb(TEXTURE_ID, TEX_MATRIX, FRAME_WIDTH, FRAME_HEIGHT, VIEWPORT_X,
+        VIEWPORT_Y, VIEWPORT_WIDTH, VIEWPORT_HEIGHT);
+    glGenericDrawer.drawOes(TEXTURE_ID, TEX_MATRIX, FRAME_WIDTH, FRAME_HEIGHT, VIEWPORT_X,
+        VIEWPORT_Y, VIEWPORT_WIDTH, VIEWPORT_HEIGHT);
+
+    // Expect two shaders to be created, and two frames to be drawn.
+    verify(mockedCallbacks, times(2)).onNewShader(mockedShader);
+    verify(mockedCallbacks, times(2))
+        .onPrepareShader(
+            mockedShader, TEX_MATRIX, FRAME_WIDTH, FRAME_HEIGHT, VIEWPORT_WIDTH, VIEWPORT_HEIGHT);
+  }
+}